1// Copyright 2015 The etcd Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15// Package v2discovery provides an implementation of the cluster discovery that
16// is used by etcd with v2 client.
17package v2discovery
18
19import (
20	"context"
21	"errors"
22	"fmt"
23	"math"
24	"net/http"
25	"net/url"
26	"path"
27	"sort"
28	"strconv"
29	"strings"
30	"time"
31
32	"go.etcd.io/etcd/client"
33	"go.etcd.io/etcd/pkg/transport"
34	"go.etcd.io/etcd/pkg/types"
35
36	"github.com/jonboulle/clockwork"
37	"go.uber.org/zap"
38)
39
// Sentinel errors reported by the discovery client. Callers compare against
// these values to decide how to react to a failed discovery attempt.
var (
	// ErrInvalidURL is returned when the collected member list does not parse
	// as a URLs map.
	ErrInvalidURL = errors.New("discovery: invalid URL")

	// ErrBadSizeKey is returned when the cluster size key holds an unusable value.
	ErrBadSizeKey = errors.New("discovery: size key is bad")

	// ErrSizeNotFound is returned when the cluster size key does not exist.
	ErrSizeNotFound = errors.New("discovery: size key not found")

	// ErrTokenNotFound reports a missing discovery token. It is not returned
	// by the code in this file.
	ErrTokenNotFound = errors.New("discovery: token not found")

	// ErrDuplicateID is returned when this member's key already exists.
	ErrDuplicateID = errors.New("discovery: found duplicate id")

	// ErrDuplicateName is returned when the parsed member map has a different
	// number of entries than the expected cluster size.
	ErrDuplicateName = errors.New("discovery: found duplicate name")

	// ErrFullCluster is returned when this member is not among the first
	// "size" registered members.
	ErrFullCluster = errors.New("discovery: cluster is full")

	// ErrTooManyRetries is returned once the retry budget (nRetries) is spent.
	ErrTooManyRetries = errors.New("discovery: too many retries")

	// ErrBadDiscoveryEndpoint is returned when the discovery endpoint replies
	// with something that is not valid JSON.
	ErrBadDiscoveryEndpoint = errors.New("discovery: bad discovery endpoint")
)
51
// Retry tuning for talking to the discovery service.
var (
	// nRetries is how many retries discovery attempts before giving up with
	// ErrTooManyRetries.
	nRetries uint = math.MaxUint32

	// maxExpoentialRetries caps the exponent used by logAndBackoffForRetry:
	// after this many retries the backoff stops doubling and stays at
	// 2^maxExpoentialRetries seconds. (The misspelled name is kept because
	// other code in the package may reference it.)
	maxExpoentialRetries uint = 8
)
57
58// JoinCluster will connect to the discovery service at the given url, and
59// register the server represented by the given id and config to the cluster
60func JoinCluster(lg *zap.Logger, durl, dproxyurl string, id types.ID, config string) (string, error) {
61	d, err := newDiscovery(lg, durl, dproxyurl, id)
62	if err != nil {
63		return "", err
64	}
65	return d.joinCluster(config)
66}
67
68// GetCluster will connect to the discovery service at the given url and
69// retrieve a string describing the cluster
70func GetCluster(lg *zap.Logger, durl, dproxyurl string) (string, error) {
71	d, err := newDiscovery(lg, durl, dproxyurl, 0)
72	if err != nil {
73		return "", err
74	}
75	return d.getCluster()
76}
77
// discovery holds the state of one session against a v2 discovery service.
// Instances are built by newDiscovery.
type discovery struct {
	// lg receives informational and warning messages about discovery progress.
	lg      *zap.Logger
	// cluster is the discovery token: the path of the discovery URL, used as
	// the key prefix for the size key, member keys and watches.
	cluster string
	// id identifies this member; selfKey is /<cluster>/<id>.
	id      types.ID
	// c is the v2 keys API client used for Get, Create and Watcher calls.
	c       client.KeysAPI
	// retries counts backoff retries taken so far; compared against nRetries.
	retries uint
	// url is the discovery endpoint with its path cleared; used in log fields.
	url     *url.URL

	// clock provides Sleep for retry backoff. newDiscovery installs the real
	// clock; NOTE(review): presumably abstracted so tests can use a fake clock.
	clock clockwork.Clock
}
88
89// newProxyFunc builds a proxy function from the given string, which should
90// represent a URL that can be used as a proxy. It performs basic
91// sanitization of the URL and returns any error encountered.
92func newProxyFunc(lg *zap.Logger, proxy string) (func(*http.Request) (*url.URL, error), error) {
93	if lg == nil {
94		lg = zap.NewNop()
95	}
96	if proxy == "" {
97		return nil, nil
98	}
99	// Do a small amount of URL sanitization to help the user
100	// Derived from net/http.ProxyFromEnvironment
101	proxyURL, err := url.Parse(proxy)
102	if err != nil || !strings.HasPrefix(proxyURL.Scheme, "http") {
103		// proxy was bogus. Try prepending "http://" to it and
104		// see if that parses correctly. If not, we ignore the
105		// error and complain about the original one
106		var err2 error
107		proxyURL, err2 = url.Parse("http://" + proxy)
108		if err2 == nil {
109			err = nil
110		}
111	}
112	if err != nil {
113		return nil, fmt.Errorf("invalid proxy address %q: %v", proxy, err)
114	}
115
116	lg.Info("running proxy with discovery", zap.String("proxy-url", proxyURL.String()))
117	return http.ProxyURL(proxyURL), nil
118}
119
120func newDiscovery(lg *zap.Logger, durl, dproxyurl string, id types.ID) (*discovery, error) {
121	if lg == nil {
122		lg = zap.NewNop()
123	}
124	u, err := url.Parse(durl)
125	if err != nil {
126		return nil, err
127	}
128	token := u.Path
129	u.Path = ""
130	pf, err := newProxyFunc(lg, dproxyurl)
131	if err != nil {
132		return nil, err
133	}
134
135	// TODO: add ResponseHeaderTimeout back when watch on discovery service writes header early
136	tr, err := transport.NewTransport(transport.TLSInfo{}, 30*time.Second)
137	if err != nil {
138		return nil, err
139	}
140	tr.Proxy = pf
141	cfg := client.Config{
142		Transport: tr,
143		Endpoints: []string{u.String()},
144	}
145	c, err := client.New(cfg)
146	if err != nil {
147		return nil, err
148	}
149	dc := client.NewKeysAPIWithPrefix(c, "")
150	return &discovery{
151		lg:      lg,
152		cluster: token,
153		c:       dc,
154		id:      id,
155		url:     u,
156		clock:   clockwork.NewRealClock(),
157	}, nil
158}
159
160func (d *discovery) joinCluster(config string) (string, error) {
161	// fast path: if the cluster is full, return the error
162	// do not need to register to the cluster in this case.
163	if _, _, _, err := d.checkCluster(); err != nil {
164		return "", err
165	}
166
167	if err := d.createSelf(config); err != nil {
168		// Fails, even on a timeout, if createSelf times out.
169		// TODO(barakmich): Retrying the same node might want to succeed here
170		// (ie, createSelf should be idempotent for discovery).
171		return "", err
172	}
173
174	nodes, size, index, err := d.checkCluster()
175	if err != nil {
176		return "", err
177	}
178
179	all, err := d.waitNodes(nodes, size, index)
180	if err != nil {
181		return "", err
182	}
183
184	return nodesToCluster(all, size)
185}
186
187func (d *discovery) getCluster() (string, error) {
188	nodes, size, index, err := d.checkCluster()
189	if err != nil {
190		if err == ErrFullCluster {
191			return nodesToCluster(nodes, size)
192		}
193		return "", err
194	}
195
196	all, err := d.waitNodes(nodes, size, index)
197	if err != nil {
198		return "", err
199	}
200	return nodesToCluster(all, size)
201}
202
203func (d *discovery) createSelf(contents string) error {
204	ctx, cancel := context.WithTimeout(context.Background(), client.DefaultRequestTimeout)
205	resp, err := d.c.Create(ctx, d.selfKey(), contents)
206	cancel()
207	if err != nil {
208		if eerr, ok := err.(client.Error); ok && eerr.Code == client.ErrorCodeNodeExist {
209			return ErrDuplicateID
210		}
211		return err
212	}
213
214	// ensure self appears on the server we connected to
215	w := d.c.Watcher(d.selfKey(), &client.WatcherOptions{AfterIndex: resp.Node.CreatedIndex - 1})
216	_, err = w.Next(context.Background())
217	return err
218}
219
220func (d *discovery) checkCluster() ([]*client.Node, uint64, uint64, error) {
221	configKey := path.Join("/", d.cluster, "_config")
222	ctx, cancel := context.WithTimeout(context.Background(), client.DefaultRequestTimeout)
223	// find cluster size
224	resp, err := d.c.Get(ctx, path.Join(configKey, "size"), nil)
225	cancel()
226	if err != nil {
227		if eerr, ok := err.(*client.Error); ok && eerr.Code == client.ErrorCodeKeyNotFound {
228			return nil, 0, 0, ErrSizeNotFound
229		}
230		if err == client.ErrInvalidJSON {
231			return nil, 0, 0, ErrBadDiscoveryEndpoint
232		}
233		if ce, ok := err.(*client.ClusterError); ok {
234			d.lg.Warn(
235				"failed to get from discovery server",
236				zap.String("discovery-url", d.url.String()),
237				zap.String("path", path.Join(configKey, "size")),
238				zap.Error(err),
239				zap.String("err-detail", ce.Detail()),
240			)
241			return d.checkClusterRetry()
242		}
243		return nil, 0, 0, err
244	}
245	size, err := strconv.ParseUint(resp.Node.Value, 10, 0)
246	if err != nil {
247		return nil, 0, 0, ErrBadSizeKey
248	}
249
250	ctx, cancel = context.WithTimeout(context.Background(), client.DefaultRequestTimeout)
251	resp, err = d.c.Get(ctx, d.cluster, nil)
252	cancel()
253	if err != nil {
254		if ce, ok := err.(*client.ClusterError); ok {
255			d.lg.Warn(
256				"failed to get from discovery server",
257				zap.String("discovery-url", d.url.String()),
258				zap.String("path", d.cluster),
259				zap.Error(err),
260				zap.String("err-detail", ce.Detail()),
261			)
262			return d.checkClusterRetry()
263		}
264		return nil, 0, 0, err
265	}
266	var nodes []*client.Node
267	// append non-config keys to nodes
268	for _, n := range resp.Node.Nodes {
269		if path.Base(n.Key) != path.Base(configKey) {
270			nodes = append(nodes, n)
271		}
272	}
273
274	snodes := sortableNodes{nodes}
275	sort.Sort(snodes)
276
277	// find self position
278	for i := range nodes {
279		if path.Base(nodes[i].Key) == path.Base(d.selfKey()) {
280			break
281		}
282		if uint64(i) >= size-1 {
283			return nodes[:size], size, resp.Index, ErrFullCluster
284		}
285	}
286	return nodes, size, resp.Index, nil
287}
288
289func (d *discovery) logAndBackoffForRetry(step string) {
290	d.retries++
291	// logAndBackoffForRetry stops exponential backoff when the retries are more than maxExpoentialRetries and is set to a constant backoff afterward.
292	retries := d.retries
293	if retries > maxExpoentialRetries {
294		retries = maxExpoentialRetries
295	}
296	retryTimeInSecond := time.Duration(0x1<<retries) * time.Second
297	d.lg.Info(
298		"retry connecting to discovery service",
299		zap.String("url", d.url.String()),
300		zap.String("reason", step),
301		zap.Duration("backoff", retryTimeInSecond),
302	)
303	d.clock.Sleep(retryTimeInSecond)
304}
305
306func (d *discovery) checkClusterRetry() ([]*client.Node, uint64, uint64, error) {
307	if d.retries < nRetries {
308		d.logAndBackoffForRetry("cluster status check")
309		return d.checkCluster()
310	}
311	return nil, 0, 0, ErrTooManyRetries
312}
313
314func (d *discovery) waitNodesRetry() ([]*client.Node, error) {
315	if d.retries < nRetries {
316		d.logAndBackoffForRetry("waiting for other nodes")
317		nodes, n, index, err := d.checkCluster()
318		if err != nil {
319			return nil, err
320		}
321		return d.waitNodes(nodes, n, index)
322	}
323	return nil, ErrTooManyRetries
324}
325
326func (d *discovery) waitNodes(nodes []*client.Node, size uint64, index uint64) ([]*client.Node, error) {
327	if uint64(len(nodes)) > size {
328		nodes = nodes[:size]
329	}
330	// watch from the next index
331	w := d.c.Watcher(d.cluster, &client.WatcherOptions{AfterIndex: index, Recursive: true})
332	all := make([]*client.Node, len(nodes))
333	copy(all, nodes)
334	for _, n := range all {
335		if path.Base(n.Key) == path.Base(d.selfKey()) {
336			d.lg.Info(
337				"found self from discovery server",
338				zap.String("discovery-url", d.url.String()),
339				zap.String("self", path.Base(d.selfKey())),
340			)
341		} else {
342			d.lg.Info(
343				"found peer from discovery server",
344				zap.String("discovery-url", d.url.String()),
345				zap.String("peer", path.Base(n.Key)),
346			)
347		}
348	}
349
350	// wait for others
351	for uint64(len(all)) < size {
352		d.lg.Info(
353			"found peers from discovery server; waiting for more",
354			zap.String("discovery-url", d.url.String()),
355			zap.Int("found-peers", len(all)),
356			zap.Int("needed-peers", int(size-uint64(len(all)))),
357		)
358		resp, err := w.Next(context.Background())
359		if err != nil {
360			if ce, ok := err.(*client.ClusterError); ok {
361				d.lg.Warn(
362					"error while waiting for peers",
363					zap.String("discovery-url", d.url.String()),
364					zap.Error(err),
365					zap.String("err-detail", ce.Detail()),
366				)
367				return d.waitNodesRetry()
368			}
369			return nil, err
370		}
371		d.lg.Info(
372			"found peer from discovery server",
373			zap.String("discovery-url", d.url.String()),
374			zap.String("peer", path.Base(resp.Node.Key)),
375		)
376		all = append(all, resp.Node)
377	}
378	d.lg.Info(
379		"found all needed peers from discovery server",
380		zap.String("discovery-url", d.url.String()),
381		zap.Int("found-peers", len(all)),
382	)
383	return all, nil
384}
385
386func (d *discovery) selfKey() string {
387	return path.Join("/", d.cluster, d.id.String())
388}
389
390func nodesToCluster(ns []*client.Node, size uint64) (string, error) {
391	s := make([]string, len(ns))
392	for i, n := range ns {
393		s[i] = n.Value
394	}
395	us := strings.Join(s, ",")
396	m, err := types.NewURLsMap(us)
397	if err != nil {
398		return us, ErrInvalidURL
399	}
400	if uint64(m.Len()) != size {
401		return us, ErrDuplicateName
402	}
403	return us, nil
404}
405
406type sortableNodes struct{ Nodes []*client.Node }
407
408func (ns sortableNodes) Len() int { return len(ns.Nodes) }
409func (ns sortableNodes) Less(i, j int) bool {
410	return ns.Nodes[i].CreatedIndex < ns.Nodes[j].CreatedIndex
411}
412func (ns sortableNodes) Swap(i, j int) { ns.Nodes[i], ns.Nodes[j] = ns.Nodes[j], ns.Nodes[i] }
413