1// Copyright 2013 The Go Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5package ssh 6 7import ( 8 "crypto/rand" 9 "errors" 10 "fmt" 11 "io" 12 "log" 13 "net" 14 "sync" 15) 16 17// debugHandshake, if set, prints messages sent and received. Key 18// exchange messages are printed as if DH were used, so the debug 19// messages are wrong when using ECDH. 20const debugHandshake = false 21 22// chanSize sets the amount of buffering SSH connections. This is 23// primarily for testing: setting chanSize=0 uncovers deadlocks more 24// quickly. 25const chanSize = 16 26 27// keyingTransport is a packet based transport that supports key 28// changes. It need not be thread-safe. It should pass through 29// msgNewKeys in both directions. 30type keyingTransport interface { 31 packetConn 32 33 // prepareKeyChange sets up a key change. The key change for a 34 // direction will be effected if a msgNewKeys message is sent 35 // or received. 36 prepareKeyChange(*algorithms, *kexResult) error 37} 38 39// handshakeTransport implements rekeying on top of a keyingTransport 40// and offers a thread-safe writePacket() interface. 41type handshakeTransport struct { 42 conn keyingTransport 43 config *Config 44 45 serverVersion []byte 46 clientVersion []byte 47 48 // hostKeys is non-empty if we are the server. In that case, 49 // it contains all host keys that can be used to sign the 50 // connection. 51 hostKeys []Signer 52 53 // hostKeyAlgorithms is non-empty if we are the client. In that case, 54 // we accept these key types from the server as host key. 55 hostKeyAlgorithms []string 56 57 // On read error, incoming is closed, and readError is set. 58 incoming chan []byte 59 readError error 60 61 mu sync.Mutex 62 writeError error 63 sentInitPacket []byte 64 sentInitMsg *kexInitMsg 65 pendingPackets [][]byte // Used when a key exchange is in progress. 66 67 // If the read loop wants to schedule a kex, it pings this 68 // channel, and the write loop will send out a kex 69 // message. 70 requestKex chan struct{} 71 72 // If the other side requests or confirms a kex, its kexInit 73 // packet is sent here for the write loop to find it. 74 startKex chan *pendingKex 75 76 // data for host key checking 77 hostKeyCallback HostKeyCallback 78 dialAddress string 79 remoteAddr net.Addr 80 81 // bannerCallback is non-empty if we are the client and it has been set in 82 // ClientConfig. In that case it is called during the user authentication 83 // dance to handle a custom server's message. 84 bannerCallback BannerCallback 85 86 // Algorithms agreed in the last key exchange. 87 algorithms *algorithms 88 89 readPacketsLeft uint32 90 readBytesLeft int64 91 92 writePacketsLeft uint32 93 writeBytesLeft int64 94 95 // The session ID or nil if first kex did not complete yet. 96 sessionID []byte 97} 98 99type pendingKex struct { 100 otherInit []byte 101 done chan error 102} 103 104func newHandshakeTransport(conn keyingTransport, config *Config, clientVersion, serverVersion []byte) *handshakeTransport { 105 t := &handshakeTransport{ 106 conn: conn, 107 serverVersion: serverVersion, 108 clientVersion: clientVersion, 109 incoming: make(chan []byte, chanSize), 110 requestKex: make(chan struct{}, 1), 111 startKex: make(chan *pendingKex, 1), 112 113 config: config, 114 } 115 t.resetReadThresholds() 116 t.resetWriteThresholds() 117 118 // We always start with a mandatory key exchange. 119 t.requestKex <- struct{}{} 120 return t 121} 122 123func newClientTransport(conn keyingTransport, clientVersion, serverVersion []byte, config *ClientConfig, dialAddr string, addr net.Addr) *handshakeTransport { 124 t := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion) 125 t.dialAddress = dialAddr 126 t.remoteAddr = addr 127 t.hostKeyCallback = config.HostKeyCallback 128 t.bannerCallback = config.BannerCallback 129 if config.HostKeyAlgorithms != nil { 130 t.hostKeyAlgorithms = config.HostKeyAlgorithms 131 } else { 132 t.hostKeyAlgorithms = supportedHostKeyAlgos 133 } 134 go t.readLoop() 135 go t.kexLoop() 136 return t 137} 138 139func newServerTransport(conn keyingTransport, clientVersion, serverVersion []byte, config *ServerConfig) *handshakeTransport { 140 t := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion) 141 t.hostKeys = config.hostKeys 142 go t.readLoop() 143 go t.kexLoop() 144 return t 145} 146 147func (t *handshakeTransport) getSessionID() []byte { 148 return t.sessionID 149} 150 151// waitSession waits for the session to be established. This should be 152// the first thing to call after instantiating handshakeTransport. 153func (t *handshakeTransport) waitSession() error { 154 p, err := t.readPacket() 155 if err != nil { 156 return err 157 } 158 if p[0] != msgNewKeys { 159 return fmt.Errorf("ssh: first packet should be msgNewKeys") 160 } 161 162 return nil 163} 164 165func (t *handshakeTransport) id() string { 166 if len(t.hostKeys) > 0 { 167 return "server" 168 } 169 return "client" 170} 171 172func (t *handshakeTransport) printPacket(p []byte, write bool) { 173 action := "got" 174 if write { 175 action = "sent" 176 } 177 178 if p[0] == msgChannelData || p[0] == msgChannelExtendedData { 179 log.Printf("%s %s data (packet %d bytes)", t.id(), action, len(p)) 180 } else { 181 msg, err := decode(p) 182 log.Printf("%s %s %T %v (%v)", t.id(), action, msg, msg, err) 183 } 184} 185 186func (t *handshakeTransport) readPacket() ([]byte, error) { 187 p, ok := <-t.incoming 188 if !ok { 189 return nil, t.readError 190 } 191 return p, nil 192} 193 194func (t *handshakeTransport) readLoop() { 195 first := true 196 for { 197 p, err := t.readOnePacket(first) 198 first = false 199 if err != nil { 200 t.readError = err 201 close(t.incoming) 202 break 203 } 204 if p[0] == msgIgnore || p[0] == msgDebug { 205 continue 206 } 207 t.incoming <- p 208 } 209 210 // Stop writers too. 211 t.recordWriteError(t.readError) 212 213 // Unblock the writer should it wait for this. 214 close(t.startKex) 215 216 // Don't close t.requestKex; it's also written to from writePacket. 217} 218 219func (t *handshakeTransport) pushPacket(p []byte) error { 220 if debugHandshake { 221 t.printPacket(p, true) 222 } 223 return t.conn.writePacket(p) 224} 225 226func (t *handshakeTransport) getWriteError() error { 227 t.mu.Lock() 228 defer t.mu.Unlock() 229 return t.writeError 230} 231 232func (t *handshakeTransport) recordWriteError(err error) { 233 t.mu.Lock() 234 defer t.mu.Unlock() 235 if t.writeError == nil && err != nil { 236 t.writeError = err 237 } 238} 239 240func (t *handshakeTransport) requestKeyExchange() { 241 select { 242 case t.requestKex <- struct{}{}: 243 default: 244 // something already requested a kex, so do nothing. 245 } 246} 247 248func (t *handshakeTransport) resetWriteThresholds() { 249 t.writePacketsLeft = packetRekeyThreshold 250 if t.config.RekeyThreshold > 0 { 251 t.writeBytesLeft = int64(t.config.RekeyThreshold) 252 } else if t.algorithms != nil { 253 t.writeBytesLeft = t.algorithms.w.rekeyBytes() 254 } else { 255 t.writeBytesLeft = 1 << 30 256 } 257} 258 259func (t *handshakeTransport) kexLoop() { 260 261write: 262 for t.getWriteError() == nil { 263 var request *pendingKex 264 var sent bool 265 266 for request == nil || !sent { 267 var ok bool 268 select { 269 case request, ok = <-t.startKex: 270 if !ok { 271 break write 272 } 273 case <-t.requestKex: 274 break 275 } 276 277 if !sent { 278 if err := t.sendKexInit(); err != nil { 279 t.recordWriteError(err) 280 break 281 } 282 sent = true 283 } 284 } 285 286 if err := t.getWriteError(); err != nil { 287 if request != nil { 288 request.done <- err 289 } 290 break 291 } 292 293 // We're not servicing t.requestKex, but that is OK: 294 // we never block on sending to t.requestKex. 295 296 // We're not servicing t.startKex, but the remote end 297 // has just sent us a kexInitMsg, so it can't send 298 // another key change request, until we close the done 299 // channel on the pendingKex request. 300 301 err := t.enterKeyExchange(request.otherInit) 302 303 t.mu.Lock() 304 t.writeError = err 305 t.sentInitPacket = nil 306 t.sentInitMsg = nil 307 308 t.resetWriteThresholds() 309 310 // we have completed the key exchange. Since the 311 // reader is still blocked, it is safe to clear out 312 // the requestKex channel. This avoids the situation 313 // where: 1) we consumed our own request for the 314 // initial kex, and 2) the kex from the remote side 315 // caused another send on the requestKex channel, 316 clear: 317 for { 318 select { 319 case <-t.requestKex: 320 // 321 default: 322 break clear 323 } 324 } 325 326 request.done <- t.writeError 327 328 // kex finished. Push packets that we received while 329 // the kex was in progress. Don't look at t.startKex 330 // and don't increment writtenSinceKex: if we trigger 331 // another kex while we are still busy with the last 332 // one, things will become very confusing. 333 for _, p := range t.pendingPackets { 334 t.writeError = t.pushPacket(p) 335 if t.writeError != nil { 336 break 337 } 338 } 339 t.pendingPackets = t.pendingPackets[:0] 340 t.mu.Unlock() 341 } 342 343 // drain startKex channel. We don't service t.requestKex 344 // because nobody does blocking sends there. 345 go func() { 346 for init := range t.startKex { 347 init.done <- t.writeError 348 } 349 }() 350 351 // Unblock reader. 352 t.conn.Close() 353} 354 355// The protocol uses uint32 for packet counters, so we can't let them 356// reach 1<<32. We will actually read and write more packets than 357// this, though: the other side may send more packets, and after we 358// hit this limit on writing we will send a few more packets for the 359// key exchange itself. 360const packetRekeyThreshold = (1 << 31) 361 362func (t *handshakeTransport) resetReadThresholds() { 363 t.readPacketsLeft = packetRekeyThreshold 364 if t.config.RekeyThreshold > 0 { 365 t.readBytesLeft = int64(t.config.RekeyThreshold) 366 } else if t.algorithms != nil { 367 t.readBytesLeft = t.algorithms.r.rekeyBytes() 368 } else { 369 t.readBytesLeft = 1 << 30 370 } 371} 372 373func (t *handshakeTransport) readOnePacket(first bool) ([]byte, error) { 374 p, err := t.conn.readPacket() 375 if err != nil { 376 return nil, err 377 } 378 379 if t.readPacketsLeft > 0 { 380 t.readPacketsLeft-- 381 } else { 382 t.requestKeyExchange() 383 } 384 385 if t.readBytesLeft > 0 { 386 t.readBytesLeft -= int64(len(p)) 387 } else { 388 t.requestKeyExchange() 389 } 390 391 if debugHandshake { 392 t.printPacket(p, false) 393 } 394 395 if first && p[0] != msgKexInit { 396 return nil, fmt.Errorf("ssh: first packet should be msgKexInit") 397 } 398 399 if p[0] != msgKexInit { 400 return p, nil 401 } 402 403 firstKex := t.sessionID == nil 404 405 kex := pendingKex{ 406 done: make(chan error, 1), 407 otherInit: p, 408 } 409 t.startKex <- &kex 410 err = <-kex.done 411 412 if debugHandshake { 413 log.Printf("%s exited key exchange (first %v), err %v", t.id(), firstKex, err) 414 } 415 416 if err != nil { 417 return nil, err 418 } 419 420 t.resetReadThresholds() 421 422 // By default, a key exchange is hidden from higher layers by 423 // translating it into msgIgnore. 424 successPacket := []byte{msgIgnore} 425 if firstKex { 426 // sendKexInit() for the first kex waits for 427 // msgNewKeys so the authentication process is 428 // guaranteed to happen over an encrypted transport. 429 successPacket = []byte{msgNewKeys} 430 } 431 432 return successPacket, nil 433} 434 435// sendKexInit sends a key change message. 436func (t *handshakeTransport) sendKexInit() error { 437 t.mu.Lock() 438 defer t.mu.Unlock() 439 if t.sentInitMsg != nil { 440 // kexInits may be sent either in response to the other side, 441 // or because our side wants to initiate a key change, so we 442 // may have already sent a kexInit. In that case, don't send a 443 // second kexInit. 444 return nil 445 } 446 447 msg := &kexInitMsg{ 448 KexAlgos: t.config.KeyExchanges, 449 CiphersClientServer: t.config.Ciphers, 450 CiphersServerClient: t.config.Ciphers, 451 MACsClientServer: t.config.MACs, 452 MACsServerClient: t.config.MACs, 453 CompressionClientServer: supportedCompressions, 454 CompressionServerClient: supportedCompressions, 455 } 456 io.ReadFull(rand.Reader, msg.Cookie[:]) 457 458 if len(t.hostKeys) > 0 { 459 for _, k := range t.hostKeys { 460 msg.ServerHostKeyAlgos = append( 461 msg.ServerHostKeyAlgos, k.PublicKey().Type()) 462 } 463 } else { 464 msg.ServerHostKeyAlgos = t.hostKeyAlgorithms 465 } 466 packet := Marshal(msg) 467 468 // writePacket destroys the contents, so save a copy. 469 packetCopy := make([]byte, len(packet)) 470 copy(packetCopy, packet) 471 472 if err := t.pushPacket(packetCopy); err != nil { 473 return err 474 } 475 476 t.sentInitMsg = msg 477 t.sentInitPacket = packet 478 479 return nil 480} 481 482func (t *handshakeTransport) writePacket(p []byte) error { 483 switch p[0] { 484 case msgKexInit: 485 return errors.New("ssh: only handshakeTransport can send kexInit") 486 case msgNewKeys: 487 return errors.New("ssh: only handshakeTransport can send newKeys") 488 } 489 490 t.mu.Lock() 491 defer t.mu.Unlock() 492 if t.writeError != nil { 493 return t.writeError 494 } 495 496 if t.sentInitMsg != nil { 497 // Copy the packet so the writer can reuse the buffer. 498 cp := make([]byte, len(p)) 499 copy(cp, p) 500 t.pendingPackets = append(t.pendingPackets, cp) 501 return nil 502 } 503 504 if t.writeBytesLeft > 0 { 505 t.writeBytesLeft -= int64(len(p)) 506 } else { 507 t.requestKeyExchange() 508 } 509 510 if t.writePacketsLeft > 0 { 511 t.writePacketsLeft-- 512 } else { 513 t.requestKeyExchange() 514 } 515 516 if err := t.pushPacket(p); err != nil { 517 t.writeError = err 518 } 519 520 return nil 521} 522 523func (t *handshakeTransport) Close() error { 524 return t.conn.Close() 525} 526 527func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error { 528 if debugHandshake { 529 log.Printf("%s entered key exchange", t.id()) 530 } 531 532 otherInit := &kexInitMsg{} 533 if err := Unmarshal(otherInitPacket, otherInit); err != nil { 534 return err 535 } 536 537 magics := handshakeMagics{ 538 clientVersion: t.clientVersion, 539 serverVersion: t.serverVersion, 540 clientKexInit: otherInitPacket, 541 serverKexInit: t.sentInitPacket, 542 } 543 544 clientInit := otherInit 545 serverInit := t.sentInitMsg 546 if len(t.hostKeys) == 0 { 547 clientInit, serverInit = serverInit, clientInit 548 549 magics.clientKexInit = t.sentInitPacket 550 magics.serverKexInit = otherInitPacket 551 } 552 553 var err error 554 t.algorithms, err = findAgreedAlgorithms(clientInit, serverInit) 555 if err != nil { 556 return err 557 } 558 559 // We don't send FirstKexFollows, but we handle receiving it. 560 // 561 // RFC 4253 section 7 defines the kex and the agreement method for 562 // first_kex_packet_follows. It states that the guessed packet 563 // should be ignored if the "kex algorithm and/or the host 564 // key algorithm is guessed wrong (server and client have 565 // different preferred algorithm), or if any of the other 566 // algorithms cannot be agreed upon". The other algorithms have 567 // already been checked above so the kex algorithm and host key 568 // algorithm are checked here. 569 if otherInit.FirstKexFollows && (clientInit.KexAlgos[0] != serverInit.KexAlgos[0] || clientInit.ServerHostKeyAlgos[0] != serverInit.ServerHostKeyAlgos[0]) { 570 // other side sent a kex message for the wrong algorithm, 571 // which we have to ignore. 572 if _, err := t.conn.readPacket(); err != nil { 573 return err 574 } 575 } 576 577 kex, ok := kexAlgoMap[t.algorithms.kex] 578 if !ok { 579 return fmt.Errorf("ssh: unexpected key exchange algorithm %v", t.algorithms.kex) 580 } 581 582 var result *kexResult 583 if len(t.hostKeys) > 0 { 584 result, err = t.server(kex, t.algorithms, &magics) 585 } else { 586 result, err = t.client(kex, t.algorithms, &magics) 587 } 588 589 if err != nil { 590 return err 591 } 592 593 if t.sessionID == nil { 594 t.sessionID = result.H 595 } 596 result.SessionID = t.sessionID 597 598 if err := t.conn.prepareKeyChange(t.algorithms, result); err != nil { 599 return err 600 } 601 if err = t.conn.writePacket([]byte{msgNewKeys}); err != nil { 602 return err 603 } 604 if packet, err := t.conn.readPacket(); err != nil { 605 return err 606 } else if packet[0] != msgNewKeys { 607 return unexpectedMessageError(msgNewKeys, packet[0]) 608 } 609 610 return nil 611} 612 613func (t *handshakeTransport) server(kex kexAlgorithm, algs *algorithms, magics *handshakeMagics) (*kexResult, error) { 614 var hostKey Signer 615 for _, k := range t.hostKeys { 616 if algs.hostKey == k.PublicKey().Type() { 617 hostKey = k 618 } 619 } 620 621 r, err := kex.Server(t.conn, t.config.Rand, magics, hostKey) 622 return r, err 623} 624 625func (t *handshakeTransport) client(kex kexAlgorithm, algs *algorithms, magics *handshakeMagics) (*kexResult, error) { 626 result, err := kex.Client(t.conn, t.config.Rand, magics) 627 if err != nil { 628 return nil, err 629 } 630 631 hostKey, err := ParsePublicKey(result.HostKey) 632 if err != nil { 633 return nil, err 634 } 635 636 if err := verifyHostKeySignature(hostKey, result); err != nil { 637 return nil, err 638 } 639 640 err = t.hostKeyCallback(t.dialAddress, t.remoteAddr, hostKey) 641 if err != nil { 642 return nil, err 643 } 644 645 return result, nil 646} 647