1// Copyright 2015 The etcd Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package command
16
17import (
18	"crypto/tls"
19	"errors"
20	"fmt"
21	"io"
22	"io/ioutil"
23	"os"
24	"strings"
25	"time"
26
27	"github.com/bgentry/speakeasy"
28	"go.etcd.io/etcd/clientv3"
29	"go.etcd.io/etcd/pkg/flags"
30	"go.etcd.io/etcd/pkg/srv"
31	"go.etcd.io/etcd/pkg/transport"
32
33	"github.com/spf13/cobra"
34	"github.com/spf13/pflag"
35	"go.uber.org/zap"
36	"google.golang.org/grpc/grpclog"
37)
38
// GlobalFlags are flags that are defined globally
// and are inherited by all sub-commands.
//
// NOTE(review): this type is not referenced in this file; the flag each field
// is bound to is presumably established where the root command is built.
// The per-group notes below are inferred from field names — confirm there.
type GlobalFlags struct {
	// Connection and transport-security options; presumably backing
	// --insecure-transport, --insecure-skip-tls-verify, --insecure-discovery,
	// --endpoints, --dial-timeout, --command-timeout, --keepalive-time,
	// --keepalive-timeout and --discovery-srv-name.
	Insecure              bool
	InsecureSkipVerify    bool
	InsecureDiscovery     bool
	Endpoints             []string
	DialTimeout           time.Duration
	CommandTimeOut        time.Duration
	KeepAliveTime         time.Duration
	KeepAliveTimeout      time.Duration
	DNSClusterServiceName string

	// TLS holds client TLS material (presumably --cert/--key/--cacert).
	TLS transport.TLSInfo

	// Output formatting; presumably --write-out and --hex.
	OutputFormat string
	IsHex        bool

	// Credentials; presumably --user and --password.
	User     string
	Password string

	// Debug presumably backs --debug.
	Debug bool
}
62
// secureCfg carries the transport-security options consumed by newClientCfg.
//
// cert, key, cacert and serverName populate transport.TLSInfo's CertFile,
// KeyFile, TrustedCAFile and ServerName; setting any of them enables TLS.
// insecureTransport suppresses the default empty TLS configuration when none
// of those are set, and insecureSkipVerify disables certificate verification
// whenever a TLS configuration is used.
type secureCfg struct {
	// TLS material; secureCfgFromCmd fills serverName from --discovery-srv.
	cert       string
	key        string
	cacert     string
	serverName string

	// Transport toggles read from --insecure-transport and
	// --insecure-skip-tls-verify.
	insecureTransport  bool
	insecureSkipVerify bool
}
72
// authCfg holds the credentials that newClientCfg copies into
// clientv3.Config.Username and clientv3.Config.Password.
type authCfg struct {
	username string
	password string
}
77
// discoveryCfg holds the DNS SRV discovery options: domain comes from
// --discovery-srv, insecure from --insecure-discovery (keep discovered
// http:// endpoints), and serviceName from --discovery-srv-name.
type discoveryCfg struct {
	domain      string
	insecure    bool
	serviceName string
}
83
// display renders command output. It starts as a simplePrinter and is
// replaced by initDisplayFromCmd according to --write-out and --hex.
var display printer = &simplePrinter{}
85
86func initDisplayFromCmd(cmd *cobra.Command) {
87	isHex, err := cmd.Flags().GetBool("hex")
88	if err != nil {
89		ExitWithError(ExitError, err)
90	}
91	outputType, err := cmd.Flags().GetString("write-out")
92	if err != nil {
93		ExitWithError(ExitError, err)
94	}
95	if display = NewPrinter(outputType, isHex); display == nil {
96		ExitWithError(ExitBadFeature, errors.New("unsupported output format"))
97	}
98}
99
// clientConfig is the parsed set of options used to build a clientv3 client.
//
// endpoints, dialTimeout, keepAliveTime and keepAliveTimeout map onto
// clientv3.Config's Endpoints, DialTimeout, DialKeepAliveTime and
// DialKeepAliveTimeout. scfg carries TLS options; acfg carries credentials
// and is nil when no --user was given.
type clientConfig struct {
	endpoints        []string
	dialTimeout      time.Duration
	keepAliveTime    time.Duration
	keepAliveTimeout time.Duration
	scfg             *secureCfg
	acfg             *authCfg
}
108
// discardValue is a flag value that accepts any input and keeps nothing.
// It is used to register placeholder flags whose values are never read.
type discardValue struct{}

// String returns the flag's textual value, which is always empty.
func (d *discardValue) String() string {
	return ""
}

// Set ignores s and always succeeds.
func (d *discardValue) Set(s string) error {
	return nil
}

// Type returns an empty type name.
func (d *discardValue) Type() string {
	return ""
}
114
115func clientConfigFromCmd(cmd *cobra.Command) *clientConfig {
116	lg, err := zap.NewProduction()
117	if err != nil {
118		ExitWithError(ExitError, err)
119	}
120	fs := cmd.InheritedFlags()
121	if strings.HasPrefix(cmd.Use, "watch") {
122		// silence "pkg/flags: unrecognized environment variable ETCDCTL_WATCH_KEY=foo" warnings
123		// silence "pkg/flags: unrecognized environment variable ETCDCTL_WATCH_RANGE_END=bar" warnings
124		fs.AddFlag(&pflag.Flag{Name: "watch-key", Value: &discardValue{}})
125		fs.AddFlag(&pflag.Flag{Name: "watch-range-end", Value: &discardValue{}})
126	}
127	flags.SetPflagsFromEnv(lg, "ETCDCTL", fs)
128
129	debug, err := cmd.Flags().GetBool("debug")
130	if err != nil {
131		ExitWithError(ExitError, err)
132	}
133	if debug {
134		clientv3.SetLogger(grpclog.NewLoggerV2WithVerbosity(os.Stderr, os.Stderr, os.Stderr, 4))
135		fs.VisitAll(func(f *pflag.Flag) {
136			fmt.Fprintf(os.Stderr, "%s=%v\n", flags.FlagToEnv("ETCDCTL", f.Name), f.Value)
137		})
138	} else {
139		// WARNING logs contain important information like TLS misconfirugation, but spams
140		// too many routine connection disconnects to turn on by default.
141		//
142		// See https://github.com/etcd-io/etcd/pull/9623 for background
143		clientv3.SetLogger(grpclog.NewLoggerV2(ioutil.Discard, ioutil.Discard, os.Stderr))
144	}
145
146	cfg := &clientConfig{}
147	cfg.endpoints, err = endpointsFromCmd(cmd)
148	if err != nil {
149		ExitWithError(ExitError, err)
150	}
151
152	cfg.dialTimeout = dialTimeoutFromCmd(cmd)
153	cfg.keepAliveTime = keepAliveTimeFromCmd(cmd)
154	cfg.keepAliveTimeout = keepAliveTimeoutFromCmd(cmd)
155
156	cfg.scfg = secureCfgFromCmd(cmd)
157	cfg.acfg = authCfgFromCmd(cmd)
158
159	initDisplayFromCmd(cmd)
160	return cfg
161}
162
163func mustClientCfgFromCmd(cmd *cobra.Command) *clientv3.Config {
164	cc := clientConfigFromCmd(cmd)
165	cfg, err := newClientCfg(cc.endpoints, cc.dialTimeout, cc.keepAliveTime, cc.keepAliveTimeout, cc.scfg, cc.acfg)
166	if err != nil {
167		ExitWithError(ExitBadArgs, err)
168	}
169	return cfg
170}
171
172func mustClientFromCmd(cmd *cobra.Command) *clientv3.Client {
173	cfg := clientConfigFromCmd(cmd)
174	return cfg.mustClient()
175}
176
177func (cc *clientConfig) mustClient() *clientv3.Client {
178	cfg, err := newClientCfg(cc.endpoints, cc.dialTimeout, cc.keepAliveTime, cc.keepAliveTimeout, cc.scfg, cc.acfg)
179	if err != nil {
180		ExitWithError(ExitBadArgs, err)
181	}
182
183	client, err := clientv3.New(*cfg)
184	if err != nil {
185		ExitWithError(ExitBadConnection, err)
186	}
187
188	return client
189}
190
191func newClientCfg(endpoints []string, dialTimeout, keepAliveTime, keepAliveTimeout time.Duration, scfg *secureCfg, acfg *authCfg) (*clientv3.Config, error) {
192	// set tls if any one tls option set
193	var cfgtls *transport.TLSInfo
194	tlsinfo := transport.TLSInfo{}
195	tlsinfo.Logger, _ = zap.NewProduction()
196	if scfg.cert != "" {
197		tlsinfo.CertFile = scfg.cert
198		cfgtls = &tlsinfo
199	}
200
201	if scfg.key != "" {
202		tlsinfo.KeyFile = scfg.key
203		cfgtls = &tlsinfo
204	}
205
206	if scfg.cacert != "" {
207		tlsinfo.TrustedCAFile = scfg.cacert
208		cfgtls = &tlsinfo
209	}
210
211	if scfg.serverName != "" {
212		tlsinfo.ServerName = scfg.serverName
213		cfgtls = &tlsinfo
214	}
215
216	cfg := &clientv3.Config{
217		Endpoints:            endpoints,
218		DialTimeout:          dialTimeout,
219		DialKeepAliveTime:    keepAliveTime,
220		DialKeepAliveTimeout: keepAliveTimeout,
221	}
222
223	if cfgtls != nil {
224		clientTLS, err := cfgtls.ClientConfig()
225		if err != nil {
226			return nil, err
227		}
228		cfg.TLS = clientTLS
229	}
230
231	// if key/cert is not given but user wants secure connection, we
232	// should still setup an empty tls configuration for gRPC to setup
233	// secure connection.
234	if cfg.TLS == nil && !scfg.insecureTransport {
235		cfg.TLS = &tls.Config{}
236	}
237
238	// If the user wants to skip TLS verification then we should set
239	// the InsecureSkipVerify flag in tls configuration.
240	if scfg.insecureSkipVerify && cfg.TLS != nil {
241		cfg.TLS.InsecureSkipVerify = true
242	}
243
244	if acfg != nil {
245		cfg.Username = acfg.username
246		cfg.Password = acfg.password
247	}
248
249	return cfg, nil
250}
251
// argOrStdin returns args[i] when that positional argument exists; otherwise
// it reads stdin to EOF and returns the data read as the argument.
//
// It returns an error when stdin yields no data or cannot be read. A read
// error is included in the message rather than being masked as "empty".
func argOrStdin(args []string, stdin io.Reader, i int) (string, error) {
	if i < len(args) {
		return args[i], nil
	}
	data, err := ioutil.ReadAll(stdin)
	if err != nil {
		return "", fmt.Errorf("no available argument and stdin: %v", err)
	}
	if len(data) == 0 {
		return "", errors.New("no available argument and stdin")
	}
	return string(data), nil
}
262
263func dialTimeoutFromCmd(cmd *cobra.Command) time.Duration {
264	dialTimeout, err := cmd.Flags().GetDuration("dial-timeout")
265	if err != nil {
266		ExitWithError(ExitError, err)
267	}
268	return dialTimeout
269}
270
271func keepAliveTimeFromCmd(cmd *cobra.Command) time.Duration {
272	keepAliveTime, err := cmd.Flags().GetDuration("keepalive-time")
273	if err != nil {
274		ExitWithError(ExitError, err)
275	}
276	return keepAliveTime
277}
278
279func keepAliveTimeoutFromCmd(cmd *cobra.Command) time.Duration {
280	keepAliveTimeout, err := cmd.Flags().GetDuration("keepalive-timeout")
281	if err != nil {
282		ExitWithError(ExitError, err)
283	}
284	return keepAliveTimeout
285}
286
287func secureCfgFromCmd(cmd *cobra.Command) *secureCfg {
288	cert, key, cacert := keyAndCertFromCmd(cmd)
289	insecureTr := insecureTransportFromCmd(cmd)
290	skipVerify := insecureSkipVerifyFromCmd(cmd)
291	discoveryCfg := discoveryCfgFromCmd(cmd)
292
293	if discoveryCfg.insecure {
294		discoveryCfg.domain = ""
295	}
296
297	return &secureCfg{
298		cert:       cert,
299		key:        key,
300		cacert:     cacert,
301		serverName: discoveryCfg.domain,
302
303		insecureTransport:  insecureTr,
304		insecureSkipVerify: skipVerify,
305	}
306}
307
308func insecureTransportFromCmd(cmd *cobra.Command) bool {
309	insecureTr, err := cmd.Flags().GetBool("insecure-transport")
310	if err != nil {
311		ExitWithError(ExitError, err)
312	}
313	return insecureTr
314}
315
316func insecureSkipVerifyFromCmd(cmd *cobra.Command) bool {
317	skipVerify, err := cmd.Flags().GetBool("insecure-skip-tls-verify")
318	if err != nil {
319		ExitWithError(ExitError, err)
320	}
321	return skipVerify
322}
323
324func keyAndCertFromCmd(cmd *cobra.Command) (cert, key, cacert string) {
325	var err error
326	if cert, err = cmd.Flags().GetString("cert"); err != nil {
327		ExitWithError(ExitBadArgs, err)
328	} else if cert == "" && cmd.Flags().Changed("cert") {
329		ExitWithError(ExitBadArgs, errors.New("empty string is passed to --cert option"))
330	}
331
332	if key, err = cmd.Flags().GetString("key"); err != nil {
333		ExitWithError(ExitBadArgs, err)
334	} else if key == "" && cmd.Flags().Changed("key") {
335		ExitWithError(ExitBadArgs, errors.New("empty string is passed to --key option"))
336	}
337
338	if cacert, err = cmd.Flags().GetString("cacert"); err != nil {
339		ExitWithError(ExitBadArgs, err)
340	} else if cacert == "" && cmd.Flags().Changed("cacert") {
341		ExitWithError(ExitBadArgs, errors.New("empty string is passed to --cacert option"))
342	}
343
344	return cert, key, cacert
345}
346
347func authCfgFromCmd(cmd *cobra.Command) *authCfg {
348	userFlag, err := cmd.Flags().GetString("user")
349	if err != nil {
350		ExitWithError(ExitBadArgs, err)
351	}
352	passwordFlag, err := cmd.Flags().GetString("password")
353	if err != nil {
354		ExitWithError(ExitBadArgs, err)
355	}
356
357	if userFlag == "" {
358		return nil
359	}
360
361	var cfg authCfg
362
363	if passwordFlag == "" {
364		splitted := strings.SplitN(userFlag, ":", 2)
365		if len(splitted) < 2 {
366			cfg.username = userFlag
367			cfg.password, err = speakeasy.Ask("Password: ")
368			if err != nil {
369				ExitWithError(ExitError, err)
370			}
371		} else {
372			cfg.username = splitted[0]
373			cfg.password = splitted[1]
374		}
375	} else {
376		cfg.username = userFlag
377		cfg.password = passwordFlag
378	}
379
380	return &cfg
381}
382
383func insecureDiscoveryFromCmd(cmd *cobra.Command) bool {
384	discovery, err := cmd.Flags().GetBool("insecure-discovery")
385	if err != nil {
386		ExitWithError(ExitError, err)
387	}
388	return discovery
389}
390
391func discoverySrvFromCmd(cmd *cobra.Command) string {
392	domainStr, err := cmd.Flags().GetString("discovery-srv")
393	if err != nil {
394		ExitWithError(ExitBadArgs, err)
395	}
396	return domainStr
397}
398
399func discoveryDNSClusterServiceNameFromCmd(cmd *cobra.Command) string {
400	serviceNameStr, err := cmd.Flags().GetString("discovery-srv-name")
401	if err != nil {
402		ExitWithError(ExitBadArgs, err)
403	}
404	return serviceNameStr
405}
406
407func discoveryCfgFromCmd(cmd *cobra.Command) *discoveryCfg {
408	return &discoveryCfg{
409		domain:      discoverySrvFromCmd(cmd),
410		insecure:    insecureDiscoveryFromCmd(cmd),
411		serviceName: discoveryDNSClusterServiceNameFromCmd(cmd),
412	}
413}
414
415func endpointsFromCmd(cmd *cobra.Command) ([]string, error) {
416	eps, err := endpointsFromFlagValue(cmd)
417	if err != nil {
418		return nil, err
419	}
420	// If domain discovery returns no endpoints, check endpoints flag
421	if len(eps) == 0 {
422		eps, err = cmd.Flags().GetStringSlice("endpoints")
423		if err == nil {
424			for i, ip := range eps {
425				eps[i] = strings.TrimSpace(ip)
426			}
427		}
428	}
429	return eps, err
430}
431
432func endpointsFromFlagValue(cmd *cobra.Command) ([]string, error) {
433	discoveryCfg := discoveryCfgFromCmd(cmd)
434
435	// If we still don't have domain discovery, return nothing
436	if discoveryCfg.domain == "" {
437		return []string{}, nil
438	}
439
440	srvs, err := srv.GetClient("etcd-client", discoveryCfg.domain, discoveryCfg.serviceName)
441	if err != nil {
442		return nil, err
443	}
444	eps := srvs.Endpoints
445	if discoveryCfg.insecure {
446		return eps, err
447	}
448	// strip insecure connections
449	ret := []string{}
450	for _, ep := range eps {
451		if strings.HasPrefix(ep, "http://") {
452			fmt.Fprintf(os.Stderr, "ignoring discovered insecure endpoint %q\n", ep)
453			continue
454		}
455		ret = append(ret, ep)
456	}
457	return ret, err
458}
459