1/*
2Copyright 2019 The Kubernetes Authors.
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17package connect
18
import (
	"errors"
	"fmt"
	"net"
	"os"
	"strings"
	"syscall"
	"time"

	"github.com/ishidawataru/sctp"
	"github.com/spf13/cobra"
)
30
31// CmdConnect is used by agnhost Cobra.
32var CmdConnect = &cobra.Command{
33	Use:   "connect [host:port]",
34	Short: "Attempts a TCP, UDP or SCTP connection and returns useful errors",
35	Long: `Tries to open a TCP, UDP or SCTP connection to the given host and port. On error it prints an error message prefixed with a specific fixed string that test cases can check for:
36
37* UNKNOWN - Generic/unknown (non-network) error (eg, bad arguments)
38* TIMEOUT - The connection attempt timed out
39* DNS - An error in DNS resolution
40* REFUSED - Connection refused
41* OTHER - Other networking error (eg, "no route to host")`,
42	Args: cobra.ExactArgs(1),
43	Run:  main,
44}
45
46var (
47	timeout  time.Duration
48	protocol string
49	udpData  string
50)
51
// init registers CmdConnect's flags and binds them to the package-level
// timeout, protocol and udpData variables.
func init() {
	CmdConnect.Flags().DurationVar(&timeout, "timeout", time.Duration(0), "Maximum time before returning an error")
	CmdConnect.Flags().StringVar(&protocol, "protocol", "tcp", "The protocol to use to perform the connection, can be tcp, udp or sctp")
	CmdConnect.Flags().StringVar(&udpData, "udp-data", "hostname", "The UDP payload sent to the server")
}
57
// main is CmdConnect's Run function. It dispatches the single host:port
// argument (guaranteed by cobra.ExactArgs(1)) to the connector for the
// selected --protocol; an empty protocol is treated as "tcp".
//
// An unsupported protocol is a bad argument, so it is reported with the
// UNKNOWN prefix documented on CmdConnect and the process exits with status 1.
func main(cmd *cobra.Command, args []string) {
	dest := args[0]
	switch protocol {
	case "", "tcp":
		connectTCP(dest, timeout)
	case "udp":
		connectUDP(dest, timeout, udpData)
	case "sctp":
		connectSCTP(dest, timeout)
	default:
		fmt.Fprintf(os.Stderr, "UNKNOWN: unsupported protocol %q\n", protocol)
		os.Exit(1)
	}
}
72
73func connectTCP(dest string, timeout time.Duration) {
74	// Redundantly parse and resolve the destination so we can return the correct
75	// errors if there's a problem.
76	if _, _, err := net.SplitHostPort(dest); err != nil {
77		fmt.Fprintf(os.Stderr, "UNKNOWN: %v\n", err)
78		os.Exit(1)
79	}
80	if _, err := net.ResolveTCPAddr("tcp", dest); err != nil {
81		fmt.Fprintf(os.Stderr, "DNS: %v\n", err)
82		os.Exit(1)
83	}
84
85	conn, err := net.DialTimeout("tcp", dest, timeout)
86	if err == nil {
87		conn.Close()
88		os.Exit(0)
89	}
90	if opErr, ok := err.(*net.OpError); ok {
91		if opErr.Timeout() {
92			fmt.Fprintf(os.Stderr, "TIMEOUT\n")
93			os.Exit(1)
94		} else if syscallErr, ok := opErr.Err.(*os.SyscallError); ok {
95			if syscallErr.Err == syscall.ECONNREFUSED {
96				fmt.Fprintf(os.Stderr, "REFUSED\n")
97				os.Exit(1)
98			}
99		}
100	}
101
102	fmt.Fprintf(os.Stderr, "OTHER: %v\n", err)
103	os.Exit(1)
104}
105
106func connectSCTP(dest string, timeout time.Duration) {
107	addr, err := sctp.ResolveSCTPAddr("sctp", dest)
108	if err != nil {
109		fmt.Fprintf(os.Stderr, "DNS: %v\n", err)
110		os.Exit(1)
111	}
112
113	timeoutCh := time.After(timeout)
114	errCh := make(chan (error))
115
116	go func() {
117		conn, err := sctp.DialSCTP("sctp", nil, addr)
118		if err == nil {
119			conn.Close()
120		}
121		errCh <- err
122	}()
123
124	select {
125	case err := <-errCh:
126		if err != nil {
127			fmt.Fprintf(os.Stderr, "OTHER: %v\n", err)
128			os.Exit(1)
129		}
130	case <-timeoutCh:
131		fmt.Fprint(os.Stderr, "TIMEOUT\n")
132		os.Exit(1)
133	}
134}
135
136func connectUDP(dest string, timeout time.Duration, data string) {
137	var (
138		readBytes int
139		buf       = make([]byte, 1024)
140	)
141
142	if _, err := net.ResolveUDPAddr("udp", dest); err != nil {
143		fmt.Fprintf(os.Stderr, "DNS: %v\n", err)
144		os.Exit(1)
145	}
146
147	conn, err := net.Dial("udp", dest)
148	if err != nil {
149		fmt.Fprintf(os.Stderr, "OTHER: %v\n", err)
150		os.Exit(1)
151	}
152
153	if timeout > 0 {
154		if err = conn.SetDeadline(time.Now().Add(timeout)); err != nil {
155			fmt.Fprintf(os.Stderr, "OTHER: %v\n", err)
156			os.Exit(1)
157		}
158	}
159
160	if _, err = conn.Write([]byte(fmt.Sprintf("%s\n", data))); err != nil {
161		parseUDPErrorAndExit(err)
162	}
163
164	if readBytes, err = conn.Read(buf); err != nil {
165		parseUDPErrorAndExit(err)
166	}
167
168	// ensure the response from UDP server
169	if readBytes == 0 {
170		fmt.Fprintf(os.Stderr, "OTHER: No data received from the server. Cannot guarantee the server received the request.\n")
171		os.Exit(1)
172	}
173}
174
175func parseUDPErrorAndExit(err error) {
176	neterr, ok := err.(net.Error)
177	if ok && neterr.Timeout() {
178		fmt.Fprintf(os.Stderr, "TIMEOUT: %v\n", err)
179	} else if strings.Contains(err.Error(), "connection refused") {
180		fmt.Fprintf(os.Stderr, "REFUSED: %v\n", err)
181	} else {
182		fmt.Fprintf(os.Stderr, "UNKNOWN: %v\n", err)
183	}
184	os.Exit(1)
185}
186