1// Copyright 2011 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package ssh
6
7import (
8	"bytes"
9	"encoding/binary"
10	"errors"
11	"fmt"
12	"io"
13	"math/big"
14	"reflect"
15	"strconv"
16	"strings"
17)
18
// These are SSH message type numbers. They are scattered around several
// documents but many were taken from [SSH-PARAMETERS].
//
// The four constants in this group have no companion message struct in this
// file, and decode has no case for any of them.
const (
	msgIgnore        = 2
	msgUnimplemented = 3
	msgDebug         = 4
	msgNewKeys       = 21
)
27
// SSH messages:
//
// These structures mirror the wire format of the corresponding SSH messages.
// They are marshaled using reflection with the Marshal and Unmarshal functions
// in this file. The only wrinkle is that a final member of type []byte with an
// ssh tag of "rest" receives the remainder of a packet when unmarshaling.
34
// See RFC 4253, section 11.1.
const msgDisconnect = 1

// disconnectMsg carries a disconnect notification. Besides being a wire
// message (type 1) it implements the error interface, and it is the error
// type returned from mux.Wait().
type disconnectMsg struct {
	Reason   uint32 `sshtype:"1"`
	Message  string
	Language string
}

// Error formats the numeric reason together with the message text. The
// Language field does not appear in the result.
func (d *disconnectMsg) Error() string {
	code, text := d.Reason, d.Message
	return fmt.Sprintf("ssh: disconnect, reason %d: %s", code, text)
}
49
// See RFC 4253, section 7.1.
const msgKexInit = 20

// kexInitMsg mirrors the wire layout of message type 20. Cookie is a fixed
// 16-byte array, so Marshal and Unmarshal copy it without a length prefix.
// Each []string field is encoded as one length-prefixed, comma-separated
// name-list, FirstKexFollows as a single byte and Reserved as a big-endian
// uint32.
type kexInitMsg struct {
	Cookie                  [16]byte `sshtype:"20"`
	KexAlgos                []string
	ServerHostKeyAlgos      []string
	CiphersClientServer     []string
	CiphersServerClient     []string
	MACsClientServer        []string
	MACsServerClient        []string
	CompressionClientServer []string
	CompressionServerClient []string
	LanguagesClientServer   []string
	LanguagesServerClient   []string
	FirstKexFollows         bool
	Reserved                uint32
}
68
// See RFC 4253, section 8.

// Diffie-Hellman
const msgKexDHInit = 30

// kexDHInitMsg mirrors message type 30. X is encoded as a length-prefixed
// two's-complement integer (see parseInt and marshalInt). decode maps type
// byte 30 to this struct.
type kexDHInitMsg struct {
	X *big.Int `sshtype:"30"`
}
77
// msgKexECDHInit has the same value as msgKexDHInit (30).
const msgKexECDHInit = 30

// kexECDHInitMsg mirrors message type 30 with a single length-prefixed byte
// string. Because the number is shared with kexDHInitMsg, decode never
// produces this struct; it returns kexDHInitMsg for type byte 30.
type kexECDHInitMsg struct {
	ClientPubKey []byte `sshtype:"30"`
}
83
// msgKexECDHReply has the same value as msgKexDHReply (31).
const msgKexECDHReply = 31

// kexECDHReplyMsg mirrors message type 31 as three length-prefixed byte
// strings. Because the number is shared with kexDHReplyMsg, decode never
// produces this struct; it returns kexDHReplyMsg for type byte 31.
type kexECDHReplyMsg struct {
	HostKey         []byte `sshtype:"31"`
	EphemeralPubKey []byte
	Signature       []byte
}
91
// msgKexDHReply is the type number of kexDHReplyMsg.
const msgKexDHReply = 31

// kexDHReplyMsg mirrors message type 31. HostKey and Signature are
// length-prefixed byte strings; Y is a length-prefixed two's-complement
// integer. decode maps type byte 31 to this struct.
type kexDHReplyMsg struct {
	HostKey   []byte `sshtype:"31"`
	Y         *big.Int
	Signature []byte
}
99
// See RFC 4253, section 10.
const msgServiceRequest = 5

// serviceRequestMsg mirrors message type 5: a single length-prefixed string.
type serviceRequestMsg struct {
	Service string `sshtype:"5"`
}
106
// See RFC 4253, section 10.
const msgServiceAccept = 6

// serviceAcceptMsg mirrors message type 6: a single length-prefixed string.
type serviceAcceptMsg struct {
	Service string `sshtype:"6"`
}
113
// See RFC 4252, section 5.
const msgUserAuthRequest = 50

// userAuthRequestMsg mirrors message type 50: three length-prefixed strings
// followed by Payload, which is tagged "rest" and therefore receives every
// byte after Method without a length prefix.
type userAuthRequestMsg struct {
	User    string `sshtype:"50"`
	Service string
	Method  string
	Payload []byte `ssh:"rest"`
}
123
// Used for debug printouts of packets.
//
// userAuthSuccessMsg has no fields and so carries no sshtype tag. decode
// returns a new value of this type for msgUserAuthSuccess directly, without
// calling Unmarshal.
type userAuthSuccessMsg struct {
}
127
// See RFC 4252, section 5.1
const msgUserAuthFailure = 51

// userAuthFailureMsg mirrors message type 51: a length-prefixed,
// comma-separated name-list followed by a one-byte boolean.
type userAuthFailureMsg struct {
	Methods        []string `sshtype:"51"`
	PartialSuccess bool
}
135
// See RFC 4252, section 5.1
const msgUserAuthSuccess = 52

// See RFC 4252, section 5.4
const msgUserAuthBanner = 53

// userAuthBannerMsg mirrors message type 53: two length-prefixed strings.
// decode has no case for msgUserAuthBanner.
type userAuthBannerMsg struct {
	Message string `sshtype:"53"`
	// unused, but required to allow message parsing
	Language string
}
147
// See RFC 4256, section 3.2
const msgUserAuthInfoRequest = 60
const msgUserAuthInfoResponse = 61

// userAuthInfoRequestMsg mirrors message type 60. No struct in this file is
// tagged with msgUserAuthInfoResponse (61).
//
// The number 60 is shared with msgUserAuthPubKeyOk, and decode maps type
// byte 60 to userAuthPubKeyOkMsg, never to this struct.
//
// Prompts is tagged "rest": Unmarshal stores the raw remainder of the packet
// in it and does not use NumPrompts to interpret those bytes.
type userAuthInfoRequestMsg struct {
	User               string `sshtype:"60"`
	Instruction        string
	DeprecatedLanguage string
	NumPrompts         uint32
	Prompts            []byte `ssh:"rest"`
}
159
// See RFC 4254, section 5.1.
const msgChannelOpen = 90

// channelOpenMsg mirrors message type 90: a length-prefixed string, three
// big-endian uint32 values, and TypeSpecificData, which is tagged "rest" and
// receives whatever bytes remain.
type channelOpenMsg struct {
	ChanType         string `sshtype:"90"`
	PeersID          uint32
	PeersWindow      uint32
	MaxPacketSize    uint32
	TypeSpecificData []byte `ssh:"rest"`
}
170
// msgChannelExtendedData (95) has no companion struct in this file and no
// case in decode; msgChannelData (94) is the tag of channelDataMsg.
const msgChannelExtendedData = 95
const msgChannelData = 94

// Used for debug print outs of packets.
//
// Length is read as a plain uint32 and Rest takes the remainder of the
// packet; Unmarshal does not check that Length matches len(Rest).
type channelDataMsg struct {
	PeersID uint32 `sshtype:"94"`
	Length  uint32
	Rest    []byte `ssh:"rest"`
}
180
// See RFC 4254, section 5.1.
const msgChannelOpenConfirm = 91

// channelOpenConfirmMsg mirrors message type 91: four big-endian uint32
// values followed by TypeSpecificData, which is tagged "rest" and receives
// whatever bytes remain.
type channelOpenConfirmMsg struct {
	PeersID          uint32 `sshtype:"91"`
	MyID             uint32
	MyWindow         uint32
	MaxPacketSize    uint32
	TypeSpecificData []byte `ssh:"rest"`
}
191
// See RFC 4254, section 5.1.
const msgChannelOpenFailure = 92

// channelOpenFailureMsg mirrors message type 92.
//
// NOTE(review): RejectionReason is declared outside this file. Marshal and
// Unmarshal dispatch on the field's reflect kind, so it presumably has an
// unsigned integer underlying type (likely uint32) - confirm.
type channelOpenFailureMsg struct {
	PeersID  uint32 `sshtype:"92"`
	Reason   RejectionReason
	Message  string
	Language string
}
201
// msgChannelRequest is the type number of channelRequestMsg.
const msgChannelRequest = 98

// channelRequestMsg mirrors message type 98: a big-endian uint32, a
// length-prefixed string, a one-byte boolean, and RequestSpecificData, which
// is tagged "rest" and receives whatever bytes remain.
type channelRequestMsg struct {
	PeersID             uint32 `sshtype:"98"`
	Request             string
	WantReply           bool
	RequestSpecificData []byte `ssh:"rest"`
}
210
// See RFC 4254, section 5.4.
const msgChannelSuccess = 99

// channelRequestSuccessMsg mirrors message type 99: one big-endian uint32.
type channelRequestSuccessMsg struct {
	PeersID uint32 `sshtype:"99"`
}
217
// See RFC 4254, section 5.4.
const msgChannelFailure = 100

// channelRequestFailureMsg mirrors message type 100: one big-endian uint32.
type channelRequestFailureMsg struct {
	PeersID uint32 `sshtype:"100"`
}
224
// See RFC 4254, section 5.3
const msgChannelClose = 97

// channelCloseMsg mirrors message type 97: one big-endian uint32.
type channelCloseMsg struct {
	PeersID uint32 `sshtype:"97"`
}
231
// See RFC 4254, section 5.3
const msgChannelEOF = 96

// channelEOFMsg mirrors message type 96: one big-endian uint32.
type channelEOFMsg struct {
	PeersID uint32 `sshtype:"96"`
}
238
// See RFC 4254, section 4
const msgGlobalRequest = 80

// globalRequestMsg mirrors message type 80: a length-prefixed string, a
// one-byte boolean, and Data, which is tagged "rest" and receives whatever
// bytes remain.
type globalRequestMsg struct {
	Type      string `sshtype:"80"`
	WantReply bool
	Data      []byte `ssh:"rest"`
}
247
// See RFC 4254, section 4
const msgRequestSuccess = 81

// globalRequestSuccessMsg mirrors message type 81. Its only field carries
// both tags: "sshtype" because it is the first field, and "rest" so that it
// receives every byte after the type byte.
type globalRequestSuccessMsg struct {
	Data []byte `ssh:"rest" sshtype:"81"`
}
254
// See RFC 4254, section 4
const msgRequestFailure = 82

// globalRequestFailureMsg mirrors message type 82. Its only field carries
// both tags: "sshtype" because it is the first field, and "rest" so that it
// receives every byte after the type byte.
type globalRequestFailureMsg struct {
	Data []byte `ssh:"rest" sshtype:"82"`
}
261
// See RFC 4254, section 5.2
const msgChannelWindowAdjust = 93

// windowAdjustMsg mirrors message type 93: two big-endian uint32 values.
type windowAdjustMsg struct {
	PeersID         uint32 `sshtype:"93"`
	AdditionalBytes uint32
}
269
// See RFC 4252, section 7
const msgUserAuthPubKeyOk = 60

// userAuthPubKeyOkMsg mirrors message type 60: a length-prefixed string
// followed by a length-prefixed byte string. The number 60 is also the value
// of msgUserAuthInfoRequest; decode maps type byte 60 to this struct.
type userAuthPubKeyOkMsg struct {
	Algo   string `sshtype:"60"`
	PubKey []byte
}
277
// typeTags returns the possible message type bytes for structType. They are
// read from the "sshtype" tag of the struct's first field, which holds one or
// more decimal numbers separated by a '|' character. Entries that do not
// parse as an integer are skipped; parsed values are converted with byte(),
// so anything outside 0-255 is truncated.
//
// A nil type, a type that is not a struct, or a struct without fields has no
// first field to inspect. In those cases typeTags returns nil instead of
// letting reflect panic, which makes field-less messages such as
// userAuthSuccessMsg safe to pass in.
func typeTags(structType reflect.Type) (tags []byte) {
	// The Kind check must come first: NumField itself panics on non-structs.
	if structType == nil || structType.Kind() != reflect.Struct || structType.NumField() == 0 {
		return nil
	}

	tagStr := structType.Field(0).Tag.Get("sshtype")

	for _, tag := range strings.Split(tagStr, "|") {
		i, err := strconv.Atoi(tag)
		if err == nil {
			tags = append(tags, byte(i))
		}
	}

	return tags
}
292
// fieldError builds the error reported when field number field of struct
// type t cannot be unmarshaled. The message names the field and the type;
// a non-empty problem is appended after a ": " separator.
func fieldError(t reflect.Type, field int, problem string) error {
	detail := ""
	if len(problem) > 0 {
		detail = ": " + problem
	}
	name, owner := t.Field(field).Name, t.Name()
	return fmt.Errorf("ssh: unmarshal error for field %s of type %s%s", name, owner, detail)
}
299
// errShortRead is returned by Unmarshal when the data ends before a field
// has been read completely.
var errShortRead = errors.New("ssh: short read")
301
302// Unmarshal parses data in SSH wire format into a structure. The out
303// argument should be a pointer to struct. If the first member of the
304// struct has the "sshtype" tag set to a '|'-separated set of numbers
305// in decimal, the packet must start with one of those numbers. In
306// case of error, Unmarshal returns a ParseError or
307// UnexpectedMessageError.
308func Unmarshal(data []byte, out interface{}) error {
309	v := reflect.ValueOf(out).Elem()
310	structType := v.Type()
311	expectedTypes := typeTags(structType)
312
313	var expectedType byte
314	if len(expectedTypes) > 0 {
315		expectedType = expectedTypes[0]
316	}
317
318	if len(data) == 0 {
319		return parseError(expectedType)
320	}
321
322	if len(expectedTypes) > 0 {
323		goodType := false
324		for _, e := range expectedTypes {
325			if e > 0 && data[0] == e {
326				goodType = true
327				break
328			}
329		}
330		if !goodType {
331			return fmt.Errorf("ssh: unexpected message type %d (expected one of %v)", data[0], expectedTypes)
332		}
333		data = data[1:]
334	}
335
336	var ok bool
337	for i := 0; i < v.NumField(); i++ {
338		field := v.Field(i)
339		t := field.Type()
340		switch t.Kind() {
341		case reflect.Bool:
342			if len(data) < 1 {
343				return errShortRead
344			}
345			field.SetBool(data[0] != 0)
346			data = data[1:]
347		case reflect.Array:
348			if t.Elem().Kind() != reflect.Uint8 {
349				return fieldError(structType, i, "array of unsupported type")
350			}
351			if len(data) < t.Len() {
352				return errShortRead
353			}
354			for j, n := 0, t.Len(); j < n; j++ {
355				field.Index(j).Set(reflect.ValueOf(data[j]))
356			}
357			data = data[t.Len():]
358		case reflect.Uint64:
359			var u64 uint64
360			if u64, data, ok = parseUint64(data); !ok {
361				return errShortRead
362			}
363			field.SetUint(u64)
364		case reflect.Uint32:
365			var u32 uint32
366			if u32, data, ok = parseUint32(data); !ok {
367				return errShortRead
368			}
369			field.SetUint(uint64(u32))
370		case reflect.Uint8:
371			if len(data) < 1 {
372				return errShortRead
373			}
374			field.SetUint(uint64(data[0]))
375			data = data[1:]
376		case reflect.String:
377			var s []byte
378			if s, data, ok = parseString(data); !ok {
379				return fieldError(structType, i, "")
380			}
381			field.SetString(string(s))
382		case reflect.Slice:
383			switch t.Elem().Kind() {
384			case reflect.Uint8:
385				if structType.Field(i).Tag.Get("ssh") == "rest" {
386					field.Set(reflect.ValueOf(data))
387					data = nil
388				} else {
389					var s []byte
390					if s, data, ok = parseString(data); !ok {
391						return errShortRead
392					}
393					field.Set(reflect.ValueOf(s))
394				}
395			case reflect.String:
396				var nl []string
397				if nl, data, ok = parseNameList(data); !ok {
398					return errShortRead
399				}
400				field.Set(reflect.ValueOf(nl))
401			default:
402				return fieldError(structType, i, "slice of unsupported type")
403			}
404		case reflect.Ptr:
405			if t == bigIntType {
406				var n *big.Int
407				if n, data, ok = parseInt(data); !ok {
408					return errShortRead
409				}
410				field.Set(reflect.ValueOf(n))
411			} else {
412				return fieldError(structType, i, "pointer to unsupported type")
413			}
414		default:
415			return fieldError(structType, i, fmt.Sprintf("unsupported type: %v", t))
416		}
417	}
418
419	if len(data) != 0 {
420		return parseError(expectedType)
421	}
422
423	return nil
424}
425
426// Marshal serializes the message in msg to SSH wire format.  The msg
427// argument should be a struct or pointer to struct. If the first
428// member has the "sshtype" tag set to a number in decimal, that
429// number is prepended to the result. If the last of member has the
430// "ssh" tag set to "rest", its contents are appended to the output.
431func Marshal(msg interface{}) []byte {
432	out := make([]byte, 0, 64)
433	return marshalStruct(out, msg)
434}
435
436func marshalStruct(out []byte, msg interface{}) []byte {
437	v := reflect.Indirect(reflect.ValueOf(msg))
438	msgTypes := typeTags(v.Type())
439	if len(msgTypes) > 0 {
440		out = append(out, msgTypes[0])
441	}
442
443	for i, n := 0, v.NumField(); i < n; i++ {
444		field := v.Field(i)
445		switch t := field.Type(); t.Kind() {
446		case reflect.Bool:
447			var v uint8
448			if field.Bool() {
449				v = 1
450			}
451			out = append(out, v)
452		case reflect.Array:
453			if t.Elem().Kind() != reflect.Uint8 {
454				panic(fmt.Sprintf("array of non-uint8 in field %d: %T", i, field.Interface()))
455			}
456			for j, l := 0, t.Len(); j < l; j++ {
457				out = append(out, uint8(field.Index(j).Uint()))
458			}
459		case reflect.Uint32:
460			out = appendU32(out, uint32(field.Uint()))
461		case reflect.Uint64:
462			out = appendU64(out, uint64(field.Uint()))
463		case reflect.Uint8:
464			out = append(out, uint8(field.Uint()))
465		case reflect.String:
466			s := field.String()
467			out = appendInt(out, len(s))
468			out = append(out, s...)
469		case reflect.Slice:
470			switch t.Elem().Kind() {
471			case reflect.Uint8:
472				if v.Type().Field(i).Tag.Get("ssh") != "rest" {
473					out = appendInt(out, field.Len())
474				}
475				out = append(out, field.Bytes()...)
476			case reflect.String:
477				offset := len(out)
478				out = appendU32(out, 0)
479				if n := field.Len(); n > 0 {
480					for j := 0; j < n; j++ {
481						f := field.Index(j)
482						if j != 0 {
483							out = append(out, ',')
484						}
485						out = append(out, f.String()...)
486					}
487					// overwrite length value
488					binary.BigEndian.PutUint32(out[offset:], uint32(len(out)-offset-4))
489				}
490			default:
491				panic(fmt.Sprintf("slice of unknown type in field %d: %T", i, field.Interface()))
492			}
493		case reflect.Ptr:
494			if t == bigIntType {
495				var n *big.Int
496				nValue := reflect.ValueOf(&n)
497				nValue.Elem().Set(field)
498				needed := intLength(n)
499				oldLength := len(out)
500
501				if cap(out)-len(out) < needed {
502					newOut := make([]byte, len(out), 2*(len(out)+needed))
503					copy(newOut, out)
504					out = newOut
505				}
506				out = out[:oldLength+needed]
507				marshalInt(out[oldLength:], n)
508			} else {
509				panic(fmt.Sprintf("pointer to unknown type in field %d: %T", i, field.Interface()))
510			}
511		}
512	}
513
514	return out
515}
516
// bigOne is the integer 1 as a shared *big.Int, used by the two's-complement
// conversions in this file. It is only ever read here and must not be
// modified.
var bigOne = big.NewInt(1)
518
// parseString reads one length-prefixed string from the front of in: a
// big-endian uint32 byte count followed by that many bytes. It returns the
// string contents, the bytes following it, and ok == true. Both returned
// slices alias in. If in is too short for the prefix or for the announced
// length, it returns nil, nil, false.
func parseString(in []byte) (out, rest []byte, ok bool) {
	const prefix = 4
	if len(in) < prefix {
		return nil, nil, false
	}
	size := binary.BigEndian.Uint32(in[:prefix])
	body := in[prefix:]
	if size > uint32(len(body)) {
		return nil, nil, false
	}
	return body[:size], body[size:], true
}
533
// comma is the separator between the entries of a name-list, and
// emptyNameList is the non-nil, zero-length slice that parseNameList returns
// for a name-list with no contents.
var (
	comma         = []byte{','}
	emptyNameList = []string{}
)
538
539func parseNameList(in []byte) (out []string, rest []byte, ok bool) {
540	contents, rest, ok := parseString(in)
541	if !ok {
542		return
543	}
544	if len(contents) == 0 {
545		out = emptyNameList
546		return
547	}
548	parts := bytes.Split(contents, comma)
549	out = make([]string, len(parts))
550	for i, part := range parts {
551		out[i] = string(part)
552	}
553	return
554}
555
556func parseInt(in []byte) (out *big.Int, rest []byte, ok bool) {
557	contents, rest, ok := parseString(in)
558	if !ok {
559		return
560	}
561	out = new(big.Int)
562
563	if len(contents) > 0 && contents[0]&0x80 == 0x80 {
564		// This is a negative number
565		notBytes := make([]byte, len(contents))
566		for i := range notBytes {
567			notBytes[i] = ^contents[i]
568		}
569		out.SetBytes(notBytes)
570		out.Add(out, bigOne)
571		out.Neg(out)
572	} else {
573		// Positive number
574		out.SetBytes(contents)
575	}
576	ok = true
577	return
578}
579
// parseUint32 reads a big-endian uint32 from the front of in and returns it
// with the remaining bytes and true. With fewer than four bytes available it
// returns 0, nil, false.
func parseUint32(in []byte) (uint32, []byte, bool) {
	const width = 4
	if len(in) >= width {
		return binary.BigEndian.Uint32(in[:width]), in[width:], true
	}
	return 0, nil, false
}
586
// parseUint64 reads a big-endian uint64 from the front of in and returns it
// with the remaining bytes and true. With fewer than eight bytes available
// it returns 0, nil, false.
func parseUint64(in []byte) (uint64, []byte, bool) {
	const width = 8
	if len(in) >= width {
		return binary.BigEndian.Uint64(in[:width]), in[width:], true
	}
	return 0, nil, false
}
593
594func intLength(n *big.Int) int {
595	length := 4 /* length bytes */
596	if n.Sign() < 0 {
597		nMinus1 := new(big.Int).Neg(n)
598		nMinus1.Sub(nMinus1, bigOne)
599		bitLen := nMinus1.BitLen()
600		if bitLen%8 == 0 {
601			// The number will need 0xff padding
602			length++
603		}
604		length += (bitLen + 7) / 8
605	} else if n.Sign() == 0 {
606		// A zero is the zero length string
607	} else {
608		bitLen := n.BitLen()
609		if bitLen%8 == 0 {
610			// The number will need 0x00 padding
611			length++
612		}
613		length += (bitLen + 7) / 8
614	}
615
616	return length
617}
618
// marshalUint32 stores n big-endian in the first four bytes of to and
// returns the part of to that follows them. It panics if to is shorter than
// four bytes.
func marshalUint32(to []byte, n uint32) []byte {
	tail := to[4:]
	binary.BigEndian.PutUint32(to, n)
	return tail
}
623
// marshalUint64 stores n big-endian in the first eight bytes of to and
// returns the part of to that follows them. It panics if to is shorter than
// eight bytes.
func marshalUint64(to []byte, n uint64) []byte {
	tail := to[8:]
	binary.BigEndian.PutUint64(to, n)
	return tail
}
628
629func marshalInt(to []byte, n *big.Int) []byte {
630	lengthBytes := to
631	to = to[4:]
632	length := 0
633
634	if n.Sign() < 0 {
635		// A negative number has to be converted to two's-complement
636		// form. So we'll subtract 1 and invert. If the
637		// most-significant-bit isn't set then we'll need to pad the
638		// beginning with 0xff in order to keep the number negative.
639		nMinus1 := new(big.Int).Neg(n)
640		nMinus1.Sub(nMinus1, bigOne)
641		bytes := nMinus1.Bytes()
642		for i := range bytes {
643			bytes[i] ^= 0xff
644		}
645		if len(bytes) == 0 || bytes[0]&0x80 == 0 {
646			to[0] = 0xff
647			to = to[1:]
648			length++
649		}
650		nBytes := copy(to, bytes)
651		to = to[nBytes:]
652		length += nBytes
653	} else if n.Sign() == 0 {
654		// A zero is the zero length string
655	} else {
656		bytes := n.Bytes()
657		if len(bytes) > 0 && bytes[0]&0x80 != 0 {
658			// We'll have to pad this with a 0x00 in order to
659			// stop it looking like a negative number.
660			to[0] = 0
661			to = to[1:]
662			length++
663		}
664		nBytes := copy(to, bytes)
665		to = to[nBytes:]
666		length += nBytes
667	}
668
669	lengthBytes[0] = byte(length >> 24)
670	lengthBytes[1] = byte(length >> 16)
671	lengthBytes[2] = byte(length >> 8)
672	lengthBytes[3] = byte(length)
673	return to
674}
675
676func writeInt(w io.Writer, n *big.Int) {
677	length := intLength(n)
678	buf := make([]byte, length)
679	marshalInt(buf, n)
680	w.Write(buf)
681}
682
// writeString writes s to w as a length-prefixed string: first a four-byte
// big-endian length, then the bytes of s, using two separate Write calls.
// The results of both Write calls are discarded.
func writeString(w io.Writer, s []byte) {
	n := len(s)
	header := [4]byte{byte(n >> 24), byte(n >> 16), byte(n >> 8), byte(n)}
	w.Write(header[:])
	w.Write(s)
}
692
// stringLength returns the encoded size of a string with n bytes of
// content: the content plus its four-byte length prefix.
func stringLength(n int) int {
	const prefix = 4
	return prefix + n
}
696
// marshalString writes s into to as a length-prefixed string (four-byte
// big-endian length, then the bytes) and returns the part of to after the
// written bytes. It panics if to is shorter than 4+len(s) bytes.
func marshalString(to []byte, s []byte) []byte {
	size := len(s)
	// Emit the length one byte at a time, most significant byte first.
	for i, shift := 0, 24; i < 4; i, shift = i+1, shift-8 {
		to[i] = byte(size >> shift)
	}
	payload := to[4:]
	copy(payload, s)
	return payload[size:]
}
706
// bigIntType is the reflect.Type of *big.Int. Unmarshal and marshalStruct
// compare pointer-typed fields against it, since *big.Int is the only
// pointer field type they support.
var bigIntType = reflect.TypeOf((*big.Int)(nil))
708
709// Decode a packet into its corresponding message.
710func decode(packet []byte) (interface{}, error) {
711	var msg interface{}
712	switch packet[0] {
713	case msgDisconnect:
714		msg = new(disconnectMsg)
715	case msgServiceRequest:
716		msg = new(serviceRequestMsg)
717	case msgServiceAccept:
718		msg = new(serviceAcceptMsg)
719	case msgKexInit:
720		msg = new(kexInitMsg)
721	case msgKexDHInit:
722		msg = new(kexDHInitMsg)
723	case msgKexDHReply:
724		msg = new(kexDHReplyMsg)
725	case msgUserAuthRequest:
726		msg = new(userAuthRequestMsg)
727	case msgUserAuthSuccess:
728		return new(userAuthSuccessMsg), nil
729	case msgUserAuthFailure:
730		msg = new(userAuthFailureMsg)
731	case msgUserAuthPubKeyOk:
732		msg = new(userAuthPubKeyOkMsg)
733	case msgGlobalRequest:
734		msg = new(globalRequestMsg)
735	case msgRequestSuccess:
736		msg = new(globalRequestSuccessMsg)
737	case msgRequestFailure:
738		msg = new(globalRequestFailureMsg)
739	case msgChannelOpen:
740		msg = new(channelOpenMsg)
741	case msgChannelData:
742		msg = new(channelDataMsg)
743	case msgChannelOpenConfirm:
744		msg = new(channelOpenConfirmMsg)
745	case msgChannelOpenFailure:
746		msg = new(channelOpenFailureMsg)
747	case msgChannelWindowAdjust:
748		msg = new(windowAdjustMsg)
749	case msgChannelEOF:
750		msg = new(channelEOFMsg)
751	case msgChannelClose:
752		msg = new(channelCloseMsg)
753	case msgChannelRequest:
754		msg = new(channelRequestMsg)
755	case msgChannelSuccess:
756		msg = new(channelRequestSuccessMsg)
757	case msgChannelFailure:
758		msg = new(channelRequestFailureMsg)
759	default:
760		return nil, unexpectedMessageError(0, packet[0])
761	}
762	if err := Unmarshal(packet, msg); err != nil {
763		return nil, err
764	}
765	return msg, nil
766}
767