1// Copyright 2013 The Go Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5package ssh 6 7import ( 8 "bytes" 9 "crypto/rand" 10 "errors" 11 "fmt" 12 "net" 13 "runtime" 14 "strings" 15 "sync" 16 "testing" 17) 18 19type testChecker struct { 20 calls []string 21} 22 23func (t *testChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error { 24 if dialAddr == "bad" { 25 return fmt.Errorf("dialAddr is bad") 26 } 27 28 if tcpAddr, ok := addr.(*net.TCPAddr); !ok || tcpAddr == nil { 29 return fmt.Errorf("testChecker: got %T want *net.TCPAddr", addr) 30 } 31 32 t.calls = append(t.calls, fmt.Sprintf("%s %v %s %x", dialAddr, addr, key.Type(), key.Marshal())) 33 34 return nil 35} 36 37// netPipe is analogous to net.Pipe, but it uses a real net.Conn, and 38// therefore is buffered (net.Pipe deadlocks if both sides start with 39// a write.) 40func netPipe() (net.Conn, net.Conn, error) { 41 listener, err := net.Listen("tcp", "127.0.0.1:0") 42 if err != nil { 43 return nil, nil, err 44 } 45 defer listener.Close() 46 c1, err := net.Dial("tcp", listener.Addr().String()) 47 if err != nil { 48 return nil, nil, err 49 } 50 51 c2, err := listener.Accept() 52 if err != nil { 53 c1.Close() 54 return nil, nil, err 55 } 56 57 return c1, c2, nil 58} 59 60func handshakePair(clientConf *ClientConfig, addr string) (client *handshakeTransport, server *handshakeTransport, err error) { 61 a, b, err := netPipe() 62 if err != nil { 63 return nil, nil, err 64 } 65 66 trC := newTransport(a, rand.Reader, true) 67 trS := newTransport(b, rand.Reader, false) 68 clientConf.SetDefaults() 69 70 v := []byte("version") 71 client = newClientTransport(trC, v, v, clientConf, addr, a.RemoteAddr()) 72 73 serverConf := &ServerConfig{} 74 serverConf.AddHostKey(testSigners["ecdsa"]) 75 serverConf.AddHostKey(testSigners["rsa"]) 76 serverConf.SetDefaults() 77 server = newServerTransport(trS, v, v, serverConf) 78 79 return client, server, nil 80} 81 82func TestHandshakeBasic(t *testing.T) { 83 if runtime.GOOS == "plan9" { 84 t.Skip("see golang.org/issue/7237") 85 } 86 checker := &testChecker{} 87 trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr") 88 if err != nil { 89 t.Fatalf("handshakePair: %v", err) 90 } 91 92 defer trC.Close() 93 defer trS.Close() 94 95 go func() { 96 // Client writes a bunch of stuff, and does a key 97 // change in the middle. This should not confuse the 98 // handshake in progress 99 for i := 0; i < 10; i++ { 100 p := []byte{msgRequestSuccess, byte(i)} 101 if err := trC.writePacket(p); err != nil { 102 t.Fatalf("sendPacket: %v", err) 103 } 104 if i == 5 { 105 // halfway through, we request a key change. 106 _, _, err := trC.sendKexInit() 107 if err != nil { 108 t.Fatalf("sendKexInit: %v", err) 109 } 110 } 111 } 112 trC.Close() 113 }() 114 115 // Server checks that client messages come in cleanly 116 i := 0 117 for { 118 p, err := trS.readPacket() 119 if err != nil { 120 break 121 } 122 if p[0] == msgNewKeys { 123 continue 124 } 125 want := []byte{msgRequestSuccess, byte(i)} 126 if bytes.Compare(p, want) != 0 { 127 t.Errorf("message %d: got %q, want %q", i, p, want) 128 } 129 i++ 130 } 131 if i != 10 { 132 t.Errorf("received %d messages, want 10.", i) 133 } 134 135 // If all went well, we registered exactly 1 key change. 136 if len(checker.calls) != 1 { 137 t.Fatalf("got %d host key checks, want 1", len(checker.calls)) 138 } 139 140 pub := testSigners["ecdsa"].PublicKey() 141 want := fmt.Sprintf("%s %v %s %x", "addr", trC.remoteAddr, pub.Type(), pub.Marshal()) 142 if want != checker.calls[0] { 143 t.Errorf("got %q want %q for host key check", checker.calls[0], want) 144 } 145} 146 147func TestHandshakeError(t *testing.T) { 148 checker := &testChecker{} 149 trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "bad") 150 if err != nil { 151 t.Fatalf("handshakePair: %v", err) 152 } 153 defer trC.Close() 154 defer trS.Close() 155 156 // send a packet 157 packet := []byte{msgRequestSuccess, 42} 158 if err := trC.writePacket(packet); err != nil { 159 t.Errorf("writePacket: %v", err) 160 } 161 162 // Now request a key change. 163 _, _, err = trC.sendKexInit() 164 if err != nil { 165 t.Errorf("sendKexInit: %v", err) 166 } 167 168 // the key change will fail, and afterwards we can't write. 169 if err := trC.writePacket([]byte{msgRequestSuccess, 43}); err == nil { 170 t.Errorf("writePacket after botched rekey succeeded.") 171 } 172 173 readback, err := trS.readPacket() 174 if err != nil { 175 t.Fatalf("server closed too soon: %v", err) 176 } 177 if bytes.Compare(readback, packet) != 0 { 178 t.Errorf("got %q want %q", readback, packet) 179 } 180 readback, err = trS.readPacket() 181 if err == nil { 182 t.Errorf("got a message %q after failed key change", readback) 183 } 184} 185 186func TestHandshakeTwice(t *testing.T) { 187 checker := &testChecker{} 188 trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr") 189 if err != nil { 190 t.Fatalf("handshakePair: %v", err) 191 } 192 193 defer trC.Close() 194 defer trS.Close() 195 196 // send a packet 197 packet := make([]byte, 5) 198 packet[0] = msgRequestSuccess 199 if err := trC.writePacket(packet); err != nil { 200 t.Errorf("writePacket: %v", err) 201 } 202 203 // Now request a key change. 204 _, _, err = trC.sendKexInit() 205 if err != nil { 206 t.Errorf("sendKexInit: %v", err) 207 } 208 209 // Send another packet. Use a fresh one, since writePacket destroys. 210 packet = make([]byte, 5) 211 packet[0] = msgRequestSuccess 212 if err := trC.writePacket(packet); err != nil { 213 t.Errorf("writePacket: %v", err) 214 } 215 216 // 2nd key change. 217 _, _, err = trC.sendKexInit() 218 if err != nil { 219 t.Errorf("sendKexInit: %v", err) 220 } 221 222 packet = make([]byte, 5) 223 packet[0] = msgRequestSuccess 224 if err := trC.writePacket(packet); err != nil { 225 t.Errorf("writePacket: %v", err) 226 } 227 228 packet = make([]byte, 5) 229 packet[0] = msgRequestSuccess 230 for i := 0; i < 5; i++ { 231 msg, err := trS.readPacket() 232 if err != nil { 233 t.Fatalf("server closed too soon: %v", err) 234 } 235 if msg[0] == msgNewKeys { 236 continue 237 } 238 239 if bytes.Compare(msg, packet) != 0 { 240 t.Errorf("packet %d: got %q want %q", i, msg, packet) 241 } 242 } 243 if len(checker.calls) != 2 { 244 t.Errorf("got %d key changes, want 2", len(checker.calls)) 245 } 246} 247 248func TestHandshakeAutoRekeyWrite(t *testing.T) { 249 checker := &testChecker{} 250 clientConf := &ClientConfig{HostKeyCallback: checker.Check} 251 clientConf.RekeyThreshold = 500 252 trC, trS, err := handshakePair(clientConf, "addr") 253 if err != nil { 254 t.Fatalf("handshakePair: %v", err) 255 } 256 defer trC.Close() 257 defer trS.Close() 258 259 for i := 0; i < 5; i++ { 260 packet := make([]byte, 251) 261 packet[0] = msgRequestSuccess 262 if err := trC.writePacket(packet); err != nil { 263 t.Errorf("writePacket: %v", err) 264 } 265 } 266 267 j := 0 268 for ; j < 5; j++ { 269 _, err := trS.readPacket() 270 if err != nil { 271 break 272 } 273 } 274 275 if j != 5 { 276 t.Errorf("got %d, want 5 messages", j) 277 } 278 279 if len(checker.calls) != 2 { 280 t.Errorf("got %d key changes, wanted 2", len(checker.calls)) 281 } 282} 283 284type syncChecker struct { 285 called chan int 286} 287 288func (t *syncChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error { 289 t.called <- 1 290 return nil 291} 292 293func TestHandshakeAutoRekeyRead(t *testing.T) { 294 sync := &syncChecker{make(chan int, 2)} 295 clientConf := &ClientConfig{ 296 HostKeyCallback: sync.Check, 297 } 298 clientConf.RekeyThreshold = 500 299 300 trC, trS, err := handshakePair(clientConf, "addr") 301 if err != nil { 302 t.Fatalf("handshakePair: %v", err) 303 } 304 defer trC.Close() 305 defer trS.Close() 306 307 packet := make([]byte, 501) 308 packet[0] = msgRequestSuccess 309 if err := trS.writePacket(packet); err != nil { 310 t.Fatalf("writePacket: %v", err) 311 } 312 // While we read out the packet, a key change will be 313 // initiated. 314 if _, err := trC.readPacket(); err != nil { 315 t.Fatalf("readPacket(client): %v", err) 316 } 317 318 <-sync.called 319} 320 321// errorKeyingTransport generates errors after a given number of 322// read/write operations. 323type errorKeyingTransport struct { 324 packetConn 325 readLeft, writeLeft int 326} 327 328func (n *errorKeyingTransport) prepareKeyChange(*algorithms, *kexResult) error { 329 return nil 330} 331func (n *errorKeyingTransport) getSessionID() []byte { 332 return nil 333} 334 335func (n *errorKeyingTransport) writePacket(packet []byte) error { 336 if n.writeLeft == 0 { 337 n.Close() 338 return errors.New("barf") 339 } 340 341 n.writeLeft-- 342 return n.packetConn.writePacket(packet) 343} 344 345func (n *errorKeyingTransport) readPacket() ([]byte, error) { 346 if n.readLeft == 0 { 347 n.Close() 348 return nil, errors.New("barf") 349 } 350 351 n.readLeft-- 352 return n.packetConn.readPacket() 353} 354 355func TestHandshakeErrorHandlingRead(t *testing.T) { 356 for i := 0; i < 20; i++ { 357 testHandshakeErrorHandlingN(t, i, -1) 358 } 359} 360 361func TestHandshakeErrorHandlingWrite(t *testing.T) { 362 for i := 0; i < 20; i++ { 363 testHandshakeErrorHandlingN(t, -1, i) 364 } 365} 366 367// testHandshakeErrorHandlingN runs handshakes, injecting errors. If 368// handshakeTransport deadlocks, the go runtime will detect it and 369// panic. 370func testHandshakeErrorHandlingN(t *testing.T, readLimit, writeLimit int) { 371 msg := Marshal(&serviceRequestMsg{strings.Repeat("x", int(minRekeyThreshold)/4)}) 372 373 a, b := memPipe() 374 defer a.Close() 375 defer b.Close() 376 377 key := testSigners["ecdsa"] 378 serverConf := Config{RekeyThreshold: minRekeyThreshold} 379 serverConf.SetDefaults() 380 serverConn := newHandshakeTransport(&errorKeyingTransport{a, readLimit, writeLimit}, &serverConf, []byte{'a'}, []byte{'b'}) 381 serverConn.hostKeys = []Signer{key} 382 go serverConn.readLoop() 383 384 clientConf := Config{RekeyThreshold: 10 * minRekeyThreshold} 385 clientConf.SetDefaults() 386 clientConn := newHandshakeTransport(&errorKeyingTransport{b, -1, -1}, &clientConf, []byte{'a'}, []byte{'b'}) 387 clientConn.hostKeyAlgorithms = []string{key.PublicKey().Type()} 388 go clientConn.readLoop() 389 390 var wg sync.WaitGroup 391 wg.Add(4) 392 393 for _, hs := range []packetConn{serverConn, clientConn} { 394 go func(c packetConn) { 395 for { 396 err := c.writePacket(msg) 397 if err != nil { 398 break 399 } 400 } 401 wg.Done() 402 }(hs) 403 go func(c packetConn) { 404 for { 405 _, err := c.readPacket() 406 if err != nil { 407 break 408 } 409 } 410 wg.Done() 411 }(hs) 412 } 413 414 wg.Wait() 415} 416