// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package ssh

import (
	"bytes"
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"math/big"
	"reflect"
	"strconv"
	"strings"
)

// These are SSH message type numbers. They are scattered around several
// documents but many were taken from [SSH-PARAMETERS].
const (
	msgIgnore        = 2
	msgUnimplemented = 3
	msgDebug         = 4
	msgNewKeys       = 21
)

// SSH messages:
//
// These structures mirror the wire format of the corresponding SSH messages.
// They are marshaled using reflection with the marshal and unmarshal functions
// in this file. The only wrinkle is that a final member of type []byte with a
// ssh tag of "rest" receives the remainder of a packet when unmarshaling.
//
// The "sshtype" tag on the first field of a struct is the message type
// number: Marshal prepends it to the output and Unmarshal requires the
// packet to start with it. Field order is the order on the wire.

// See RFC 4253, section 11.1.
const msgDisconnect = 1

// disconnectMsg is the message that signals a disconnect. It is also
// the error type returned from mux.Wait()
type disconnectMsg struct {
	Reason   uint32 `sshtype:"1"`
	Message  string
	Language string
}

// Error implements the error interface. The text contains the numeric
// Reason and the Message; Language is not included.
func (d *disconnectMsg) Error() string {
	return fmt.Sprintf("ssh: disconnect, reason %d: %s", d.Reason, d.Message)
}

// See RFC 4253, section 7.1.
const msgKexInit = 20

// kexInitMsg is the message with type msgKexInit (20). The marshaling code
// in this file writes Cookie as 16 raw bytes and each []string field as a
// length-prefixed, comma-separated list.
type kexInitMsg struct {
	Cookie                  [16]byte `sshtype:"20"`
	KexAlgos                []string
	ServerHostKeyAlgos      []string
	CiphersClientServer     []string
	CiphersServerClient     []string
	MACsClientServer        []string
	MACsServerClient        []string
	CompressionClientServer []string
	CompressionServerClient []string
	LanguagesClientServer   []string
	LanguagesServerClient   []string
	FirstKexFollows         bool
	Reserved                uint32
}

// See RFC 4253, section 8.

// Diffie-Hellman
const msgKexDHInit = 30

// kexDHInitMsg is the message with type 30 whose only field is the big
// integer X.
type kexDHInitMsg struct {
	X *big.Int `sshtype:"30"`
}

// msgKexECDHInit has the same number as msgKexDHInit, so the type byte alone
// cannot tell the two messages apart; decode in this file maps 30 to
// kexDHInitMsg only.
const msgKexECDHInit = 30

// kexECDHInitMsg is the message with type 30 carrying a single
// length-prefixed byte string.
type kexECDHInitMsg struct {
	ClientPubKey []byte `sshtype:"30"`
}

// msgKexECDHReply has the same number as msgKexDHReply; decode in this file
// maps 31 to kexDHReplyMsg only.
const msgKexECDHReply = 31

// kexECDHReplyMsg is the message with type 31 made of three length-prefixed
// byte strings.
type kexECDHReplyMsg struct {
	HostKey         []byte `sshtype:"31"`
	EphemeralPubKey []byte
	Signature       []byte
}

const msgKexDHReply = 31

// kexDHReplyMsg is the message with type 31 carrying a byte string, the big
// integer Y and another byte string, in that order.
type kexDHReplyMsg struct {
	HostKey   []byte `sshtype:"31"`
	Y         *big.Int
	Signature []byte
}

// See RFC 4253, section 10.
const msgServiceRequest = 5

// serviceRequestMsg is the message with type 5; Service is its only field.
type serviceRequestMsg struct {
	Service string `sshtype:"5"`
}

// See RFC 4253, section 10.
const msgServiceAccept = 6

// serviceAcceptMsg is the message with type 6; Service is its only field.
type serviceAcceptMsg struct {
	Service string `sshtype:"6"`
}

// See RFC 4252, section 5.
const msgUserAuthRequest = 50

// userAuthRequestMsg is the message with type 50. Payload carries the "rest"
// tag, so it receives every byte that follows Method when unmarshaling and is
// written without a length prefix when marshaling.
type userAuthRequestMsg struct {
	User    string `sshtype:"50"`
	Service string
	Method  string
	Payload []byte `ssh:"rest"`
}

// Used for debug printouts of packets.
// The struct has no fields and therefore no "sshtype" tag; decode returns it
// directly for msgUserAuthSuccess without calling Unmarshal.
type userAuthSuccessMsg struct {
}

// See RFC 4252, section 5.1
const msgUserAuthFailure = 51

// userAuthFailureMsg is the message with type 51: a comma-separated list of
// method names followed by a single boolean byte.
type userAuthFailureMsg struct {
	Methods        []string `sshtype:"51"`
	PartialSuccess bool
}

// See RFC 4252, section 5.1
const msgUserAuthSuccess = 52

// See RFC 4252, section 5.4
const msgUserAuthBanner = 53

// userAuthBannerMsg is the message with type 53. decode in this file has no
// case for it.
type userAuthBannerMsg struct {
	Message string `sshtype:"53"`
	// unused, but required to allow message parsing
	Language string
}

// See RFC 4256, section 3.2
const msgUserAuthInfoRequest = 60
const msgUserAuthInfoResponse = 61

// userAuthInfoRequestMsg is the message with type 60. Prompts carries the
// "rest" tag and holds the raw bytes that follow NumPrompts; they are not
// decoded by Unmarshal.
// NOTE(review): presumably Prompts contains NumPrompts encoded prompts;
// confirm against the code that consumes this struct.
type userAuthInfoRequestMsg struct {
	User               string `sshtype:"60"`
	Instruction        string
	DeprecatedLanguage string
	NumPrompts         uint32
	Prompts            []byte `ssh:"rest"`
}

// See RFC 4254, section 5.1.
const msgChannelOpen = 90

// channelOpenMsg is the message with type 90. TypeSpecificData carries the
// "rest" tag and receives every byte that follows MaxPacketSize.
type channelOpenMsg struct {
	ChanType         string `sshtype:"90"`
	PeersID          uint32
	PeersWindow      uint32
	MaxPacketSize    uint32
	TypeSpecificData []byte `ssh:"rest"`
}

// msgChannelExtendedData has no struct in this part of the file and no case
// in decode.
const msgChannelExtendedData = 95
const msgChannelData = 94

// Used for debug print outs of packets.
// Length is read as a plain uint32 and Rest takes the remaining bytes, so
// Unmarshal does not check that Length matches len(Rest).
type channelDataMsg struct {
	PeersID uint32 `sshtype:"94"`
	Length  uint32
	Rest    []byte `ssh:"rest"`
}

// See RFC 4254, section 5.1.
const msgChannelOpenConfirm = 91

// channelOpenConfirmMsg is the message with type 91. TypeSpecificData
// carries the "rest" tag and receives every byte that follows MaxPacketSize.
type channelOpenConfirmMsg struct {
	PeersID          uint32 `sshtype:"91"`
	MyID             uint32
	MyWindow         uint32
	MaxPacketSize    uint32
	TypeSpecificData []byte `ssh:"rest"`
}

// See RFC 4254, section 5.1.
const msgChannelOpenFailure = 92

// channelOpenFailureMsg is the message with type 92. RejectionReason is
// declared outside this part of the file.
// NOTE(review): the reflection code only handles Reason if its underlying
// kind is one of the supported ones (e.g. uint32); confirm its declaration.
type channelOpenFailureMsg struct {
	PeersID  uint32 `sshtype:"92"`
	Reason   RejectionReason
	Message  string
	Language string
}

const msgChannelRequest = 98

// channelRequestMsg is the message with type 98. RequestSpecificData carries
// the "rest" tag and receives every byte that follows WantReply.
type channelRequestMsg struct {
	PeersID             uint32 `sshtype:"98"`
	Request             string
	WantReply           bool
	RequestSpecificData []byte `ssh:"rest"`
}

// See RFC 4254, section 5.4.
const msgChannelSuccess = 99

// channelRequestSuccessMsg is the message with type 99; PeersID is its only
// field.
type channelRequestSuccessMsg struct {
	PeersID uint32 `sshtype:"99"`
}

// See RFC 4254, section 5.4.
const msgChannelFailure = 100

// channelRequestFailureMsg is the message with type 100; PeersID is its only
// field.
type channelRequestFailureMsg struct {
	PeersID uint32 `sshtype:"100"`
}

// See RFC 4254, section 5.3
const msgChannelClose = 97

// channelCloseMsg is the message with type 97; PeersID is its only field.
type channelCloseMsg struct {
	PeersID uint32 `sshtype:"97"`
}

// See RFC 4254, section 5.3
const msgChannelEOF = 96

// channelEOFMsg is the message with type 96; PeersID is its only field.
type channelEOFMsg struct {
	PeersID uint32 `sshtype:"96"`
}

// See RFC 4254, section 4
const msgGlobalRequest = 80

// globalRequestMsg is the message with type 80. Data carries the "rest" tag
// and receives every byte that follows WantReply.
type globalRequestMsg struct {
	Type      string `sshtype:"80"`
	WantReply bool
	Data      []byte `ssh:"rest"`
}

// See RFC 4254, section 4
const msgRequestSuccess = 81

// globalRequestSuccessMsg is the message with type 81. Its single field
// carries both tags: it supplies the type byte and holds the whole remainder
// of the packet.
type globalRequestSuccessMsg struct {
	Data []byte `ssh:"rest" sshtype:"81"`
}

// See RFC 4254, section 4
const msgRequestFailure = 82

// globalRequestFailureMsg is the message with type 82. Its single field
// carries both tags: it supplies the type byte and holds the whole remainder
// of the packet.
type globalRequestFailureMsg struct {
	Data []byte `ssh:"rest" sshtype:"82"`
}

// See RFC 4254, section 5.2
const msgChannelWindowAdjust = 93

// windowAdjustMsg is the message with type 93: two uint32 values.
type windowAdjustMsg struct {
	PeersID         uint32 `sshtype:"93"`
	AdditionalBytes uint32
}

// See RFC 4252, section 7
// msgUserAuthPubKeyOk has the same number as msgUserAuthInfoRequest; decode
// in this file maps 60 to userAuthPubKeyOkMsg only.
const msgUserAuthPubKeyOk = 60

// userAuthPubKeyOkMsg is the message with type 60 carrying a string followed
// by a length-prefixed byte string.
type userAuthPubKeyOkMsg struct {
	Algo   string `sshtype:"60"`
	PubKey []byte
}

// typeTags returns the possible type bytes for the given reflect.Type, which
// should be a struct. The possible values are read from the "sshtype" tag of
// the struct's first field and are separated by a '|' character.
280func typeTags(structType reflect.Type) (tags []byte) { 281 tagStr := structType.Field(0).Tag.Get("sshtype") 282 283 for _, tag := range strings.Split(tagStr, "|") { 284 i, err := strconv.Atoi(tag) 285 if err == nil { 286 tags = append(tags, byte(i)) 287 } 288 } 289 290 return tags 291} 292 293func fieldError(t reflect.Type, field int, problem string) error { 294 if problem != "" { 295 problem = ": " + problem 296 } 297 return fmt.Errorf("ssh: unmarshal error for field %s of type %s%s", t.Field(field).Name, t.Name(), problem) 298} 299 300var errShortRead = errors.New("ssh: short read") 301 302// Unmarshal parses data in SSH wire format into a structure. The out 303// argument should be a pointer to struct. If the first member of the 304// struct has the "sshtype" tag set to a '|'-separated set of numbers 305// in decimal, the packet must start with one of those numbers. In 306// case of error, Unmarshal returns a ParseError or 307// UnexpectedMessageError. 308func Unmarshal(data []byte, out interface{}) error { 309 v := reflect.ValueOf(out).Elem() 310 structType := v.Type() 311 expectedTypes := typeTags(structType) 312 313 var expectedType byte 314 if len(expectedTypes) > 0 { 315 expectedType = expectedTypes[0] 316 } 317 318 if len(data) == 0 { 319 return parseError(expectedType) 320 } 321 322 if len(expectedTypes) > 0 { 323 goodType := false 324 for _, e := range expectedTypes { 325 if e > 0 && data[0] == e { 326 goodType = true 327 break 328 } 329 } 330 if !goodType { 331 return fmt.Errorf("ssh: unexpected message type %d (expected one of %v)", data[0], expectedTypes) 332 } 333 data = data[1:] 334 } 335 336 var ok bool 337 for i := 0; i < v.NumField(); i++ { 338 field := v.Field(i) 339 t := field.Type() 340 switch t.Kind() { 341 case reflect.Bool: 342 if len(data) < 1 { 343 return errShortRead 344 } 345 field.SetBool(data[0] != 0) 346 data = data[1:] 347 case reflect.Array: 348 if t.Elem().Kind() != reflect.Uint8 { 349 return fieldError(structType, i, "array of 
unsupported type") 350 } 351 if len(data) < t.Len() { 352 return errShortRead 353 } 354 for j, n := 0, t.Len(); j < n; j++ { 355 field.Index(j).Set(reflect.ValueOf(data[j])) 356 } 357 data = data[t.Len():] 358 case reflect.Uint64: 359 var u64 uint64 360 if u64, data, ok = parseUint64(data); !ok { 361 return errShortRead 362 } 363 field.SetUint(u64) 364 case reflect.Uint32: 365 var u32 uint32 366 if u32, data, ok = parseUint32(data); !ok { 367 return errShortRead 368 } 369 field.SetUint(uint64(u32)) 370 case reflect.Uint8: 371 if len(data) < 1 { 372 return errShortRead 373 } 374 field.SetUint(uint64(data[0])) 375 data = data[1:] 376 case reflect.String: 377 var s []byte 378 if s, data, ok = parseString(data); !ok { 379 return fieldError(structType, i, "") 380 } 381 field.SetString(string(s)) 382 case reflect.Slice: 383 switch t.Elem().Kind() { 384 case reflect.Uint8: 385 if structType.Field(i).Tag.Get("ssh") == "rest" { 386 field.Set(reflect.ValueOf(data)) 387 data = nil 388 } else { 389 var s []byte 390 if s, data, ok = parseString(data); !ok { 391 return errShortRead 392 } 393 field.Set(reflect.ValueOf(s)) 394 } 395 case reflect.String: 396 var nl []string 397 if nl, data, ok = parseNameList(data); !ok { 398 return errShortRead 399 } 400 field.Set(reflect.ValueOf(nl)) 401 default: 402 return fieldError(structType, i, "slice of unsupported type") 403 } 404 case reflect.Ptr: 405 if t == bigIntType { 406 var n *big.Int 407 if n, data, ok = parseInt(data); !ok { 408 return errShortRead 409 } 410 field.Set(reflect.ValueOf(n)) 411 } else { 412 return fieldError(structType, i, "pointer to unsupported type") 413 } 414 default: 415 return fieldError(structType, i, fmt.Sprintf("unsupported type: %v", t)) 416 } 417 } 418 419 if len(data) != 0 { 420 return parseError(expectedType) 421 } 422 423 return nil 424} 425 426// Marshal serializes the message in msg to SSH wire format. The msg 427// argument should be a struct or pointer to struct. 
If the first 428// member has the "sshtype" tag set to a number in decimal, that 429// number is prepended to the result. If the last of member has the 430// "ssh" tag set to "rest", its contents are appended to the output. 431func Marshal(msg interface{}) []byte { 432 out := make([]byte, 0, 64) 433 return marshalStruct(out, msg) 434} 435 436func marshalStruct(out []byte, msg interface{}) []byte { 437 v := reflect.Indirect(reflect.ValueOf(msg)) 438 msgTypes := typeTags(v.Type()) 439 if len(msgTypes) > 0 { 440 out = append(out, msgTypes[0]) 441 } 442 443 for i, n := 0, v.NumField(); i < n; i++ { 444 field := v.Field(i) 445 switch t := field.Type(); t.Kind() { 446 case reflect.Bool: 447 var v uint8 448 if field.Bool() { 449 v = 1 450 } 451 out = append(out, v) 452 case reflect.Array: 453 if t.Elem().Kind() != reflect.Uint8 { 454 panic(fmt.Sprintf("array of non-uint8 in field %d: %T", i, field.Interface())) 455 } 456 for j, l := 0, t.Len(); j < l; j++ { 457 out = append(out, uint8(field.Index(j).Uint())) 458 } 459 case reflect.Uint32: 460 out = appendU32(out, uint32(field.Uint())) 461 case reflect.Uint64: 462 out = appendU64(out, uint64(field.Uint())) 463 case reflect.Uint8: 464 out = append(out, uint8(field.Uint())) 465 case reflect.String: 466 s := field.String() 467 out = appendInt(out, len(s)) 468 out = append(out, s...) 469 case reflect.Slice: 470 switch t.Elem().Kind() { 471 case reflect.Uint8: 472 if v.Type().Field(i).Tag.Get("ssh") != "rest" { 473 out = appendInt(out, field.Len()) 474 } 475 out = append(out, field.Bytes()...) 476 case reflect.String: 477 offset := len(out) 478 out = appendU32(out, 0) 479 if n := field.Len(); n > 0 { 480 for j := 0; j < n; j++ { 481 f := field.Index(j) 482 if j != 0 { 483 out = append(out, ',') 484 } 485 out = append(out, f.String()...) 
486 } 487 // overwrite length value 488 binary.BigEndian.PutUint32(out[offset:], uint32(len(out)-offset-4)) 489 } 490 default: 491 panic(fmt.Sprintf("slice of unknown type in field %d: %T", i, field.Interface())) 492 } 493 case reflect.Ptr: 494 if t == bigIntType { 495 var n *big.Int 496 nValue := reflect.ValueOf(&n) 497 nValue.Elem().Set(field) 498 needed := intLength(n) 499 oldLength := len(out) 500 501 if cap(out)-len(out) < needed { 502 newOut := make([]byte, len(out), 2*(len(out)+needed)) 503 copy(newOut, out) 504 out = newOut 505 } 506 out = out[:oldLength+needed] 507 marshalInt(out[oldLength:], n) 508 } else { 509 panic(fmt.Sprintf("pointer to unknown type in field %d: %T", i, field.Interface())) 510 } 511 } 512 } 513 514 return out 515} 516 517var bigOne = big.NewInt(1) 518 519func parseString(in []byte) (out, rest []byte, ok bool) { 520 if len(in) < 4 { 521 return 522 } 523 length := binary.BigEndian.Uint32(in) 524 in = in[4:] 525 if uint32(len(in)) < length { 526 return 527 } 528 out = in[:length] 529 rest = in[length:] 530 ok = true 531 return 532} 533 534var ( 535 comma = []byte{','} 536 emptyNameList = []string{} 537) 538 539func parseNameList(in []byte) (out []string, rest []byte, ok bool) { 540 contents, rest, ok := parseString(in) 541 if !ok { 542 return 543 } 544 if len(contents) == 0 { 545 out = emptyNameList 546 return 547 } 548 parts := bytes.Split(contents, comma) 549 out = make([]string, len(parts)) 550 for i, part := range parts { 551 out[i] = string(part) 552 } 553 return 554} 555 556func parseInt(in []byte) (out *big.Int, rest []byte, ok bool) { 557 contents, rest, ok := parseString(in) 558 if !ok { 559 return 560 } 561 out = new(big.Int) 562 563 if len(contents) > 0 && contents[0]&0x80 == 0x80 { 564 // This is a negative number 565 notBytes := make([]byte, len(contents)) 566 for i := range notBytes { 567 notBytes[i] = ^contents[i] 568 } 569 out.SetBytes(notBytes) 570 out.Add(out, bigOne) 571 out.Neg(out) 572 } else { 573 // Positive number 
574 out.SetBytes(contents) 575 } 576 ok = true 577 return 578} 579 580func parseUint32(in []byte) (uint32, []byte, bool) { 581 if len(in) < 4 { 582 return 0, nil, false 583 } 584 return binary.BigEndian.Uint32(in), in[4:], true 585} 586 587func parseUint64(in []byte) (uint64, []byte, bool) { 588 if len(in) < 8 { 589 return 0, nil, false 590 } 591 return binary.BigEndian.Uint64(in), in[8:], true 592} 593 594func intLength(n *big.Int) int { 595 length := 4 /* length bytes */ 596 if n.Sign() < 0 { 597 nMinus1 := new(big.Int).Neg(n) 598 nMinus1.Sub(nMinus1, bigOne) 599 bitLen := nMinus1.BitLen() 600 if bitLen%8 == 0 { 601 // The number will need 0xff padding 602 length++ 603 } 604 length += (bitLen + 7) / 8 605 } else if n.Sign() == 0 { 606 // A zero is the zero length string 607 } else { 608 bitLen := n.BitLen() 609 if bitLen%8 == 0 { 610 // The number will need 0x00 padding 611 length++ 612 } 613 length += (bitLen + 7) / 8 614 } 615 616 return length 617} 618 619func marshalUint32(to []byte, n uint32) []byte { 620 binary.BigEndian.PutUint32(to, n) 621 return to[4:] 622} 623 624func marshalUint64(to []byte, n uint64) []byte { 625 binary.BigEndian.PutUint64(to, n) 626 return to[8:] 627} 628 629func marshalInt(to []byte, n *big.Int) []byte { 630 lengthBytes := to 631 to = to[4:] 632 length := 0 633 634 if n.Sign() < 0 { 635 // A negative number has to be converted to two's-complement 636 // form. So we'll subtract 1 and invert. If the 637 // most-significant-bit isn't set then we'll need to pad the 638 // beginning with 0xff in order to keep the number negative. 
639 nMinus1 := new(big.Int).Neg(n) 640 nMinus1.Sub(nMinus1, bigOne) 641 bytes := nMinus1.Bytes() 642 for i := range bytes { 643 bytes[i] ^= 0xff 644 } 645 if len(bytes) == 0 || bytes[0]&0x80 == 0 { 646 to[0] = 0xff 647 to = to[1:] 648 length++ 649 } 650 nBytes := copy(to, bytes) 651 to = to[nBytes:] 652 length += nBytes 653 } else if n.Sign() == 0 { 654 // A zero is the zero length string 655 } else { 656 bytes := n.Bytes() 657 if len(bytes) > 0 && bytes[0]&0x80 != 0 { 658 // We'll have to pad this with a 0x00 in order to 659 // stop it looking like a negative number. 660 to[0] = 0 661 to = to[1:] 662 length++ 663 } 664 nBytes := copy(to, bytes) 665 to = to[nBytes:] 666 length += nBytes 667 } 668 669 lengthBytes[0] = byte(length >> 24) 670 lengthBytes[1] = byte(length >> 16) 671 lengthBytes[2] = byte(length >> 8) 672 lengthBytes[3] = byte(length) 673 return to 674} 675 676func writeInt(w io.Writer, n *big.Int) { 677 length := intLength(n) 678 buf := make([]byte, length) 679 marshalInt(buf, n) 680 w.Write(buf) 681} 682 683func writeString(w io.Writer, s []byte) { 684 var lengthBytes [4]byte 685 lengthBytes[0] = byte(len(s) >> 24) 686 lengthBytes[1] = byte(len(s) >> 16) 687 lengthBytes[2] = byte(len(s) >> 8) 688 lengthBytes[3] = byte(len(s)) 689 w.Write(lengthBytes[:]) 690 w.Write(s) 691} 692 693func stringLength(n int) int { 694 return 4 + n 695} 696 697func marshalString(to []byte, s []byte) []byte { 698 to[0] = byte(len(s) >> 24) 699 to[1] = byte(len(s) >> 16) 700 to[2] = byte(len(s) >> 8) 701 to[3] = byte(len(s)) 702 to = to[4:] 703 copy(to, s) 704 return to[len(s):] 705} 706 707var bigIntType = reflect.TypeOf((*big.Int)(nil)) 708 709// Decode a packet into its corresponding message. 
710func decode(packet []byte) (interface{}, error) { 711 var msg interface{} 712 switch packet[0] { 713 case msgDisconnect: 714 msg = new(disconnectMsg) 715 case msgServiceRequest: 716 msg = new(serviceRequestMsg) 717 case msgServiceAccept: 718 msg = new(serviceAcceptMsg) 719 case msgKexInit: 720 msg = new(kexInitMsg) 721 case msgKexDHInit: 722 msg = new(kexDHInitMsg) 723 case msgKexDHReply: 724 msg = new(kexDHReplyMsg) 725 case msgUserAuthRequest: 726 msg = new(userAuthRequestMsg) 727 case msgUserAuthSuccess: 728 return new(userAuthSuccessMsg), nil 729 case msgUserAuthFailure: 730 msg = new(userAuthFailureMsg) 731 case msgUserAuthPubKeyOk: 732 msg = new(userAuthPubKeyOkMsg) 733 case msgGlobalRequest: 734 msg = new(globalRequestMsg) 735 case msgRequestSuccess: 736 msg = new(globalRequestSuccessMsg) 737 case msgRequestFailure: 738 msg = new(globalRequestFailureMsg) 739 case msgChannelOpen: 740 msg = new(channelOpenMsg) 741 case msgChannelData: 742 msg = new(channelDataMsg) 743 case msgChannelOpenConfirm: 744 msg = new(channelOpenConfirmMsg) 745 case msgChannelOpenFailure: 746 msg = new(channelOpenFailureMsg) 747 case msgChannelWindowAdjust: 748 msg = new(windowAdjustMsg) 749 case msgChannelEOF: 750 msg = new(channelEOFMsg) 751 case msgChannelClose: 752 msg = new(channelCloseMsg) 753 case msgChannelRequest: 754 msg = new(channelRequestMsg) 755 case msgChannelSuccess: 756 msg = new(channelRequestSuccessMsg) 757 case msgChannelFailure: 758 msg = new(channelRequestFailureMsg) 759 default: 760 return nil, unexpectedMessageError(0, packet[0]) 761 } 762 if err := Unmarshal(packet, msg); err != nil { 763 return nil, err 764 } 765 return msg, nil 766} 767