1/* 2Copyright 2019 The Kubernetes Authors. 3 4Licensed under the Apache License, Version 2.0 (the "License"); 5you may not use this file except in compliance with the License. 6You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10Unless required by applicable law or agreed to in writing, software 11distributed under the License is distributed on an "AS IS" BASIS, 12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13See the License for the specific language governing permissions and 14limitations under the License. 15*/ 16 17package connect 18 19import ( 20 "fmt" 21 "net" 22 "os" 23 "strings" 24 "syscall" 25 "time" 26 27 "github.com/ishidawataru/sctp" 28 "github.com/spf13/cobra" 29) 30 31// CmdConnect is used by agnhost Cobra. 32var CmdConnect = &cobra.Command{ 33 Use: "connect [host:port]", 34 Short: "Attempts a TCP, UDP or SCTP connection and returns useful errors", 35 Long: `Tries to open a TCP, UDP or SCTP connection to the given host and port. On error it prints an error message prefixed with a specific fixed string that test cases can check for: 36 37* UNKNOWN - Generic/unknown (non-network) error (eg, bad arguments) 38* TIMEOUT - The connection attempt timed out 39* DNS - An error in DNS resolution 40* REFUSED - Connection refused 41* OTHER - Other networking error (eg, "no route to host")`, 42 Args: cobra.ExactArgs(1), 43 Run: main, 44} 45 46var ( 47 timeout time.Duration 48 protocol string 49 udpData string 50) 51 52func init() { 53 CmdConnect.Flags().DurationVar(&timeout, "timeout", time.Duration(0), "Maximum time before returning an error") 54 CmdConnect.Flags().StringVar(&protocol, "protocol", "tcp", "The protocol to use to perform the connection, can be tcp, udp or sctp") 55 CmdConnect.Flags().StringVar(&udpData, "udp-data", "hostname", "The UDP payload send to the server") 56} 57 58func main(cmd *cobra.Command, args []string) { 59 dest := args[0] 60 switch protocol { 61 case "", "tcp": 62 connectTCP(dest, timeout) 63 case "udp": 64 connectUDP(dest, timeout, udpData) 65 case "sctp": 66 connectSCTP(dest, timeout) 67 default: 68 fmt.Fprint(os.Stderr, "Unsupported protocol\n", protocol) 69 os.Exit(1) 70 } 71} 72 73func connectTCP(dest string, timeout time.Duration) { 74 // Redundantly parse and resolve the destination so we can return the correct 75 // errors if there's a problem. 76 if _, _, err := net.SplitHostPort(dest); err != nil { 77 fmt.Fprintf(os.Stderr, "UNKNOWN: %v\n", err) 78 os.Exit(1) 79 } 80 if _, err := net.ResolveTCPAddr("tcp", dest); err != nil { 81 fmt.Fprintf(os.Stderr, "DNS: %v\n", err) 82 os.Exit(1) 83 } 84 85 conn, err := net.DialTimeout("tcp", dest, timeout) 86 if err == nil { 87 conn.Close() 88 os.Exit(0) 89 } 90 if opErr, ok := err.(*net.OpError); ok { 91 if opErr.Timeout() { 92 fmt.Fprintf(os.Stderr, "TIMEOUT\n") 93 os.Exit(1) 94 } else if syscallErr, ok := opErr.Err.(*os.SyscallError); ok { 95 if syscallErr.Err == syscall.ECONNREFUSED { 96 fmt.Fprintf(os.Stderr, "REFUSED\n") 97 os.Exit(1) 98 } 99 } 100 } 101 102 fmt.Fprintf(os.Stderr, "OTHER: %v\n", err) 103 os.Exit(1) 104} 105 106func connectSCTP(dest string, timeout time.Duration) { 107 addr, err := sctp.ResolveSCTPAddr("sctp", dest) 108 if err != nil { 109 fmt.Fprintf(os.Stderr, "DNS: %v\n", err) 110 os.Exit(1) 111 } 112 113 timeoutCh := time.After(timeout) 114 errCh := make(chan (error)) 115 116 go func() { 117 conn, err := sctp.DialSCTP("sctp", nil, addr) 118 if err == nil { 119 conn.Close() 120 } 121 errCh <- err 122 }() 123 124 select { 125 case err := <-errCh: 126 if err != nil { 127 fmt.Fprintf(os.Stderr, "OTHER: %v\n", err) 128 os.Exit(1) 129 } 130 case <-timeoutCh: 131 fmt.Fprint(os.Stderr, "TIMEOUT\n") 132 os.Exit(1) 133 } 134} 135 136func connectUDP(dest string, timeout time.Duration, data string) { 137 var ( 138 readBytes int 139 buf = make([]byte, 1024) 140 ) 141 142 if _, err := net.ResolveUDPAddr("udp", dest); err != nil { 143 fmt.Fprintf(os.Stderr, "DNS: %v\n", err) 144 os.Exit(1) 145 } 146 147 conn, err := net.Dial("udp", dest) 148 if err != nil { 149 fmt.Fprintf(os.Stderr, "OTHER: %v\n", err) 150 os.Exit(1) 151 } 152 153 if timeout > 0 { 154 if err = conn.SetDeadline(time.Now().Add(timeout)); err != nil { 155 fmt.Fprintf(os.Stderr, "OTHER: %v\n", err) 156 os.Exit(1) 157 } 158 } 159 160 if _, err = conn.Write([]byte(fmt.Sprintf("%s\n", data))); err != nil { 161 parseUDPErrorAndExit(err) 162 } 163 164 if readBytes, err = conn.Read(buf); err != nil { 165 parseUDPErrorAndExit(err) 166 } 167 168 // ensure the response from UDP server 169 if readBytes == 0 { 170 fmt.Fprintf(os.Stderr, "OTHER: No data received from the server. Cannot guarantee the server received the request.\n") 171 os.Exit(1) 172 } 173} 174 175func parseUDPErrorAndExit(err error) { 176 neterr, ok := err.(net.Error) 177 if ok && neterr.Timeout() { 178 fmt.Fprintf(os.Stderr, "TIMEOUT: %v\n", err) 179 } else if strings.Contains(err.Error(), "connection refused") { 180 fmt.Fprintf(os.Stderr, "REFUSED: %v\n", err) 181 } else { 182 fmt.Fprintf(os.Stderr, "UNKNOWN: %v\n", err) 183 } 184 os.Exit(1) 185} 186