1/*
2Copyright 2015 The Kubernetes Authors.
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17package httpstream
18
19import (
20	"fmt"
21	"io"
22	"net/http"
23	"strings"
24	"time"
25)
26
// Header names used when detecting and negotiating a streaming upgrade.
const (
	// HeaderConnection is the HTTP Connection header. IsUpgradeRequest looks
	// for the HeaderUpgrade option in its values.
	HeaderConnection = "Connection"

	// HeaderUpgrade is the "Upgrade" connection option that marks a request
	// as a connection upgrade request.
	HeaderUpgrade = "Upgrade"

	// HeaderProtocolVersion carries stream protocol versions. Clients list
	// the versions they support in the request; Handshake adds the version
	// it negotiated to the response.
	HeaderProtocolVersion = "X-Stream-Protocol-Version"

	// HeaderAcceptedProtocolVersions is added to the response by Handshake
	// when negotiation fails, listing every version the server accepts.
	HeaderAcceptedProtocolVersions = "X-Accepted-Stream-Protocol-Versions"
)
33
34// NewStreamHandler defines a function that is called when a new Stream is
35// received. If no error is returned, the Stream is accepted; otherwise,
36// the stream is rejected. After the reply frame has been sent, replySent is closed.
37type NewStreamHandler func(stream Stream, replySent <-chan struct{}) error
38
39// NoOpNewStreamHandler is a stream handler that accepts a new stream and
40// performs no other logic.
41func NoOpNewStreamHandler(stream Stream, replySent <-chan struct{}) error { return nil }
42
43// Dialer knows how to open a streaming connection to a server.
44type Dialer interface {
45
46	// Dial opens a streaming connection to a server using one of the protocols
47	// specified (in order of most preferred to least preferred).
48	Dial(protocols ...string) (Connection, string, error)
49}
50
51// UpgradeRoundTripper is a type of http.RoundTripper that is able to upgrade
52// HTTP requests to support multiplexed bidirectional streams. After RoundTrip()
53// is invoked, if the upgrade is successful, clients may retrieve the upgraded
54// connection by calling UpgradeRoundTripper.Connection().
55type UpgradeRoundTripper interface {
56	http.RoundTripper
57	// NewConnection validates the response and creates a new Connection.
58	NewConnection(resp *http.Response) (Connection, error)
59}
60
61// ResponseUpgrader knows how to upgrade HTTP requests and responses to
62// add streaming support to them.
63type ResponseUpgrader interface {
64	// UpgradeResponse upgrades an HTTP response to one that supports multiplexed
65	// streams. newStreamHandler will be called asynchronously whenever the
66	// other end of the upgraded connection creates a new stream.
67	UpgradeResponse(w http.ResponseWriter, req *http.Request, newStreamHandler NewStreamHandler) Connection
68}
69
70// Connection represents an upgraded HTTP connection.
71type Connection interface {
72	// CreateStream creates a new Stream with the supplied headers.
73	CreateStream(headers http.Header) (Stream, error)
74	// Close resets all streams and closes the connection.
75	Close() error
76	// CloseChan returns a channel that is closed when the underlying connection is closed.
77	CloseChan() <-chan bool
78	// SetIdleTimeout sets the amount of time the connection may remain idle before
79	// it is automatically closed.
80	SetIdleTimeout(timeout time.Duration)
81}
82
// Stream is a bidirectional communication channel carried over an upgraded
// connection.
type Stream interface {
	io.ReadWriteCloser

	// Identifier returns the ID of the stream.
	Identifier() uint32
	// Headers returns the headers the stream was created with.
	Headers() http.Header
	// Reset closes the stream in both directions, signalling that neither the
	// client nor the server may use it any longer.
	Reset() error
}
95
96// IsUpgradeRequest returns true if the given request is a connection upgrade request
97func IsUpgradeRequest(req *http.Request) bool {
98	for _, h := range req.Header[http.CanonicalHeaderKey(HeaderConnection)] {
99		if strings.Contains(strings.ToLower(h), strings.ToLower(HeaderUpgrade)) {
100			return true
101		}
102	}
103	return false
104}
105
106func negotiateProtocol(clientProtocols, serverProtocols []string) string {
107	for i := range clientProtocols {
108		for j := range serverProtocols {
109			if clientProtocols[i] == serverProtocols[j] {
110				return clientProtocols[i]
111			}
112		}
113	}
114	return ""
115}
116
117// Handshake performs a subprotocol negotiation. If the client did request a
118// subprotocol, Handshake will select the first common value found in
119// serverProtocols. If a match is found, Handshake adds a response header
120// indicating the chosen subprotocol. If no match is found, HTTP forbidden is
121// returned, along with a response header containing the list of protocols the
122// server can accept.
123func Handshake(req *http.Request, w http.ResponseWriter, serverProtocols []string) (string, error) {
124	clientProtocols := req.Header[http.CanonicalHeaderKey(HeaderProtocolVersion)]
125	if len(clientProtocols) == 0 {
126		// Kube 1.0 clients didn't support subprotocol negotiation.
127		// TODO require clientProtocols once Kube 1.0 is no longer supported
128		return "", nil
129	}
130
131	if len(serverProtocols) == 0 {
132		// Kube 1.0 servers didn't support subprotocol negotiation. This is mainly for testing.
133		// TODO require serverProtocols once Kube 1.0 is no longer supported
134		return "", nil
135	}
136
137	negotiatedProtocol := negotiateProtocol(clientProtocols, serverProtocols)
138	if len(negotiatedProtocol) == 0 {
139		for i := range serverProtocols {
140			w.Header().Add(HeaderAcceptedProtocolVersions, serverProtocols[i])
141		}
142		err := fmt.Errorf("unable to upgrade: unable to negotiate protocol: client supports %v, server accepts %v", clientProtocols, serverProtocols)
143		http.Error(w, err.Error(), http.StatusForbidden)
144		return "", err
145	}
146
147	w.Header().Add(HeaderProtocolVersion, negotiatedProtocol)
148	return negotiatedProtocol, nil
149}
150