1// Copyright 2015 The etcd Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package command
16
17import (
18	"crypto/tls"
19	"errors"
20	"fmt"
21	"io"
22	"io/ioutil"
23	"os"
24	"strings"
25	"time"
26
27	"github.com/bgentry/speakeasy"
28	"go.etcd.io/etcd/client/pkg/v3/srv"
29	"go.etcd.io/etcd/client/pkg/v3/transport"
30	"go.etcd.io/etcd/client/v3"
31	"go.etcd.io/etcd/pkg/v3/cobrautl"
32	"go.etcd.io/etcd/pkg/v3/flags"
33
34	"github.com/spf13/cobra"
35	"github.com/spf13/pflag"
36	"go.uber.org/zap"
37	"google.golang.org/grpc/grpclog"
38)
39
40// GlobalFlags are flags that defined globally
41// and are inherited to all sub-commands.
42type GlobalFlags struct {
43	Insecure              bool
44	InsecureSkipVerify    bool
45	InsecureDiscovery     bool
46	Endpoints             []string
47	DialTimeout           time.Duration
48	CommandTimeOut        time.Duration
49	KeepAliveTime         time.Duration
50	KeepAliveTimeout      time.Duration
51	DNSClusterServiceName string
52
53	TLS transport.TLSInfo
54
55	OutputFormat string
56	IsHex        bool
57
58	User     string
59	Password string
60
61	Debug bool
62}
63
// secureCfg holds the transport-security options resolved from the command
// line; newClientCfg turns it into a TLS configuration.
type secureCfg struct {
	// Certificate, private key and trusted-CA file settings; newClientCfg
	// copies each non-empty one into a transport.TLSInfo.
	cert, key, cacert string

	// serverName, when non-empty, becomes the TLS server name.
	serverName string

	// insecureTransport suppresses the default empty TLS config that is
	// otherwise installed when no TLS option is given; insecureSkipVerify
	// disables certificate verification on whatever TLS config is in use.
	insecureTransport, insecureSkipVerify bool
}
73
// authCfg carries the credentials copied into clientv3.Config by newClientCfg.
type authCfg struct {
	username, password string
}
78
// discoveryCfg describes how endpoints are discovered via srv.GetClient.
type discoveryCfg struct {
	// domain is the discovery domain; an empty string disables discovery.
	domain string

	// insecure keeps discovered http:// endpoints instead of dropping them.
	insecure bool

	// serviceName is forwarded to srv.GetClient as the service name.
	serviceName string
}
84
85var display printer = &simplePrinter{}
86
87func initDisplayFromCmd(cmd *cobra.Command) {
88	isHex, err := cmd.Flags().GetBool("hex")
89	if err != nil {
90		cobrautl.ExitWithError(cobrautl.ExitError, err)
91	}
92	outputType, err := cmd.Flags().GetString("write-out")
93	if err != nil {
94		cobrautl.ExitWithError(cobrautl.ExitError, err)
95	}
96	if display = NewPrinter(outputType, isHex); display == nil {
97		cobrautl.ExitWithError(cobrautl.ExitBadFeature, errors.New("unsupported output format"))
98	}
99}
100
101type clientConfig struct {
102	endpoints        []string
103	dialTimeout      time.Duration
104	keepAliveTime    time.Duration
105	keepAliveTimeout time.Duration
106	scfg             *secureCfg
107	acfg             *authCfg
108}
109
// discardValue is a pflag.Value that accepts any input and remembers nothing.
// It backs placeholder flags whose only job is to keep matching environment
// variables from being reported as unrecognized.
type discardValue struct{}

// String reports an empty value.
func (d *discardValue) String() string { return "" }

// Set ignores its argument and never fails.
func (d *discardValue) Set(string) error { return nil }

// Type reports an empty type name.
func (d *discardValue) Type() string { return "" }
115
116func clientConfigFromCmd(cmd *cobra.Command) *clientConfig {
117	lg, err := zap.NewProduction()
118	if err != nil {
119		cobrautl.ExitWithError(cobrautl.ExitError, err)
120	}
121	fs := cmd.InheritedFlags()
122	if strings.HasPrefix(cmd.Use, "watch") {
123		// silence "pkg/flags: unrecognized environment variable ETCDCTL_WATCH_KEY=foo" warnings
124		// silence "pkg/flags: unrecognized environment variable ETCDCTL_WATCH_RANGE_END=bar" warnings
125		fs.AddFlag(&pflag.Flag{Name: "watch-key", Value: &discardValue{}})
126		fs.AddFlag(&pflag.Flag{Name: "watch-range-end", Value: &discardValue{}})
127	}
128	flags.SetPflagsFromEnv(lg, "ETCDCTL", fs)
129
130	debug, err := cmd.Flags().GetBool("debug")
131	if err != nil {
132		cobrautl.ExitWithError(cobrautl.ExitError, err)
133	}
134	if debug {
135		grpclog.SetLoggerV2(grpclog.NewLoggerV2WithVerbosity(os.Stderr, os.Stderr, os.Stderr, 4))
136		fs.VisitAll(func(f *pflag.Flag) {
137			fmt.Fprintf(os.Stderr, "%s=%v\n", flags.FlagToEnv("ETCDCTL", f.Name), f.Value)
138		})
139	} else {
140		// WARNING logs contain important information like TLS misconfirugation, but spams
141		// too many routine connection disconnects to turn on by default.
142		//
143		// See https://github.com/etcd-io/etcd/pull/9623 for background
144		grpclog.SetLoggerV2(grpclog.NewLoggerV2(ioutil.Discard, ioutil.Discard, os.Stderr))
145	}
146
147	cfg := &clientConfig{}
148	cfg.endpoints, err = endpointsFromCmd(cmd)
149	if err != nil {
150		cobrautl.ExitWithError(cobrautl.ExitError, err)
151	}
152
153	cfg.dialTimeout = dialTimeoutFromCmd(cmd)
154	cfg.keepAliveTime = keepAliveTimeFromCmd(cmd)
155	cfg.keepAliveTimeout = keepAliveTimeoutFromCmd(cmd)
156
157	cfg.scfg = secureCfgFromCmd(cmd)
158	cfg.acfg = authCfgFromCmd(cmd)
159
160	initDisplayFromCmd(cmd)
161	return cfg
162}
163
164func mustClientCfgFromCmd(cmd *cobra.Command) *clientv3.Config {
165	cc := clientConfigFromCmd(cmd)
166	cfg, err := newClientCfg(cc.endpoints, cc.dialTimeout, cc.keepAliveTime, cc.keepAliveTimeout, cc.scfg, cc.acfg)
167	if err != nil {
168		cobrautl.ExitWithError(cobrautl.ExitBadArgs, err)
169	}
170	return cfg
171}
172
173func mustClientFromCmd(cmd *cobra.Command) *clientv3.Client {
174	cfg := clientConfigFromCmd(cmd)
175	return cfg.mustClient()
176}
177
178func (cc *clientConfig) mustClient() *clientv3.Client {
179	cfg, err := newClientCfg(cc.endpoints, cc.dialTimeout, cc.keepAliveTime, cc.keepAliveTimeout, cc.scfg, cc.acfg)
180	if err != nil {
181		cobrautl.ExitWithError(cobrautl.ExitBadArgs, err)
182	}
183
184	client, err := clientv3.New(*cfg)
185	if err != nil {
186		cobrautl.ExitWithError(cobrautl.ExitBadConnection, err)
187	}
188
189	return client
190}
191
192func newClientCfg(endpoints []string, dialTimeout, keepAliveTime, keepAliveTimeout time.Duration, scfg *secureCfg, acfg *authCfg) (*clientv3.Config, error) {
193	// set tls if any one tls option set
194	var cfgtls *transport.TLSInfo
195	tlsinfo := transport.TLSInfo{}
196	tlsinfo.Logger, _ = zap.NewProduction()
197	if scfg.cert != "" {
198		tlsinfo.CertFile = scfg.cert
199		cfgtls = &tlsinfo
200	}
201
202	if scfg.key != "" {
203		tlsinfo.KeyFile = scfg.key
204		cfgtls = &tlsinfo
205	}
206
207	if scfg.cacert != "" {
208		tlsinfo.TrustedCAFile = scfg.cacert
209		cfgtls = &tlsinfo
210	}
211
212	if scfg.serverName != "" {
213		tlsinfo.ServerName = scfg.serverName
214		cfgtls = &tlsinfo
215	}
216
217	cfg := &clientv3.Config{
218		Endpoints:            endpoints,
219		DialTimeout:          dialTimeout,
220		DialKeepAliveTime:    keepAliveTime,
221		DialKeepAliveTimeout: keepAliveTimeout,
222	}
223
224	if cfgtls != nil {
225		clientTLS, err := cfgtls.ClientConfig()
226		if err != nil {
227			return nil, err
228		}
229		cfg.TLS = clientTLS
230	}
231
232	// if key/cert is not given but user wants secure connection, we
233	// should still setup an empty tls configuration for gRPC to setup
234	// secure connection.
235	if cfg.TLS == nil && !scfg.insecureTransport {
236		cfg.TLS = &tls.Config{}
237	}
238
239	// If the user wants to skip TLS verification then we should set
240	// the InsecureSkipVerify flag in tls configuration.
241	if scfg.insecureSkipVerify && cfg.TLS != nil {
242		cfg.TLS.InsecureSkipVerify = true
243	}
244
245	if acfg != nil {
246		cfg.Username = acfg.username
247		cfg.Password = acfg.password
248	}
249
250	return cfg, nil
251}
252
// argOrStdin returns args[i] when it exists; otherwise it reads all of stdin
// and returns its content instead.
//
// It fails when stdin is empty, or when reading stdin fails. In the read
// failure case the underlying error is wrapped (still prefixed with the
// original "no available argument and stdin" message) so the cause is not
// lost.
func argOrStdin(args []string, stdin io.Reader, i int) (string, error) {
	if i < len(args) {
		return args[i], nil
	}
	data, err := ioutil.ReadAll(stdin)
	if err != nil {
		return "", fmt.Errorf("no available argument and stdin: %w", err)
	}
	if len(data) == 0 {
		return "", errors.New("no available argument and stdin")
	}
	return string(data), nil
}
263
264func dialTimeoutFromCmd(cmd *cobra.Command) time.Duration {
265	dialTimeout, err := cmd.Flags().GetDuration("dial-timeout")
266	if err != nil {
267		cobrautl.ExitWithError(cobrautl.ExitError, err)
268	}
269	return dialTimeout
270}
271
272func keepAliveTimeFromCmd(cmd *cobra.Command) time.Duration {
273	keepAliveTime, err := cmd.Flags().GetDuration("keepalive-time")
274	if err != nil {
275		cobrautl.ExitWithError(cobrautl.ExitError, err)
276	}
277	return keepAliveTime
278}
279
280func keepAliveTimeoutFromCmd(cmd *cobra.Command) time.Duration {
281	keepAliveTimeout, err := cmd.Flags().GetDuration("keepalive-timeout")
282	if err != nil {
283		cobrautl.ExitWithError(cobrautl.ExitError, err)
284	}
285	return keepAliveTimeout
286}
287
288func secureCfgFromCmd(cmd *cobra.Command) *secureCfg {
289	cert, key, cacert := keyAndCertFromCmd(cmd)
290	insecureTr := insecureTransportFromCmd(cmd)
291	skipVerify := insecureSkipVerifyFromCmd(cmd)
292	discoveryCfg := discoveryCfgFromCmd(cmd)
293
294	if discoveryCfg.insecure {
295		discoveryCfg.domain = ""
296	}
297
298	return &secureCfg{
299		cert:       cert,
300		key:        key,
301		cacert:     cacert,
302		serverName: discoveryCfg.domain,
303
304		insecureTransport:  insecureTr,
305		insecureSkipVerify: skipVerify,
306	}
307}
308
309func insecureTransportFromCmd(cmd *cobra.Command) bool {
310	insecureTr, err := cmd.Flags().GetBool("insecure-transport")
311	if err != nil {
312		cobrautl.ExitWithError(cobrautl.ExitError, err)
313	}
314	return insecureTr
315}
316
317func insecureSkipVerifyFromCmd(cmd *cobra.Command) bool {
318	skipVerify, err := cmd.Flags().GetBool("insecure-skip-tls-verify")
319	if err != nil {
320		cobrautl.ExitWithError(cobrautl.ExitError, err)
321	}
322	return skipVerify
323}
324
325func keyAndCertFromCmd(cmd *cobra.Command) (cert, key, cacert string) {
326	var err error
327	if cert, err = cmd.Flags().GetString("cert"); err != nil {
328		cobrautl.ExitWithError(cobrautl.ExitBadArgs, err)
329	} else if cert == "" && cmd.Flags().Changed("cert") {
330		cobrautl.ExitWithError(cobrautl.ExitBadArgs, errors.New("empty string is passed to --cert option"))
331	}
332
333	if key, err = cmd.Flags().GetString("key"); err != nil {
334		cobrautl.ExitWithError(cobrautl.ExitBadArgs, err)
335	} else if key == "" && cmd.Flags().Changed("key") {
336		cobrautl.ExitWithError(cobrautl.ExitBadArgs, errors.New("empty string is passed to --key option"))
337	}
338
339	if cacert, err = cmd.Flags().GetString("cacert"); err != nil {
340		cobrautl.ExitWithError(cobrautl.ExitBadArgs, err)
341	} else if cacert == "" && cmd.Flags().Changed("cacert") {
342		cobrautl.ExitWithError(cobrautl.ExitBadArgs, errors.New("empty string is passed to --cacert option"))
343	}
344
345	return cert, key, cacert
346}
347
348func authCfgFromCmd(cmd *cobra.Command) *authCfg {
349	userFlag, err := cmd.Flags().GetString("user")
350	if err != nil {
351		cobrautl.ExitWithError(cobrautl.ExitBadArgs, err)
352	}
353	passwordFlag, err := cmd.Flags().GetString("password")
354	if err != nil {
355		cobrautl.ExitWithError(cobrautl.ExitBadArgs, err)
356	}
357
358	if userFlag == "" {
359		return nil
360	}
361
362	var cfg authCfg
363
364	if passwordFlag == "" {
365		splitted := strings.SplitN(userFlag, ":", 2)
366		if len(splitted) < 2 {
367			cfg.username = userFlag
368			cfg.password, err = speakeasy.Ask("Password: ")
369			if err != nil {
370				cobrautl.ExitWithError(cobrautl.ExitError, err)
371			}
372		} else {
373			cfg.username = splitted[0]
374			cfg.password = splitted[1]
375		}
376	} else {
377		cfg.username = userFlag
378		cfg.password = passwordFlag
379	}
380
381	return &cfg
382}
383
384func insecureDiscoveryFromCmd(cmd *cobra.Command) bool {
385	discovery, err := cmd.Flags().GetBool("insecure-discovery")
386	if err != nil {
387		cobrautl.ExitWithError(cobrautl.ExitError, err)
388	}
389	return discovery
390}
391
392func discoverySrvFromCmd(cmd *cobra.Command) string {
393	domainStr, err := cmd.Flags().GetString("discovery-srv")
394	if err != nil {
395		cobrautl.ExitWithError(cobrautl.ExitBadArgs, err)
396	}
397	return domainStr
398}
399
400func discoveryDNSClusterServiceNameFromCmd(cmd *cobra.Command) string {
401	serviceNameStr, err := cmd.Flags().GetString("discovery-srv-name")
402	if err != nil {
403		cobrautl.ExitWithError(cobrautl.ExitBadArgs, err)
404	}
405	return serviceNameStr
406}
407
408func discoveryCfgFromCmd(cmd *cobra.Command) *discoveryCfg {
409	return &discoveryCfg{
410		domain:      discoverySrvFromCmd(cmd),
411		insecure:    insecureDiscoveryFromCmd(cmd),
412		serviceName: discoveryDNSClusterServiceNameFromCmd(cmd),
413	}
414}
415
416func endpointsFromCmd(cmd *cobra.Command) ([]string, error) {
417	eps, err := endpointsFromFlagValue(cmd)
418	if err != nil {
419		return nil, err
420	}
421	// If domain discovery returns no endpoints, check endpoints flag
422	if len(eps) == 0 {
423		eps, err = cmd.Flags().GetStringSlice("endpoints")
424		if err == nil {
425			for i, ip := range eps {
426				eps[i] = strings.TrimSpace(ip)
427			}
428		}
429	}
430	return eps, err
431}
432
433func endpointsFromFlagValue(cmd *cobra.Command) ([]string, error) {
434	discoveryCfg := discoveryCfgFromCmd(cmd)
435
436	// If we still don't have domain discovery, return nothing
437	if discoveryCfg.domain == "" {
438		return []string{}, nil
439	}
440
441	srvs, err := srv.GetClient("etcd-client", discoveryCfg.domain, discoveryCfg.serviceName)
442	if err != nil {
443		return nil, err
444	}
445	eps := srvs.Endpoints
446	if discoveryCfg.insecure {
447		return eps, err
448	}
449	// strip insecure connections
450	ret := []string{}
451	for _, ep := range eps {
452		if strings.HasPrefix(ep, "http://") {
453			fmt.Fprintf(os.Stderr, "ignoring discovered insecure endpoint %q\n", ep)
454			continue
455		}
456		ret = append(ret, ep)
457	}
458	return ret, err
459}
460