1// Copyright 2011 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package ssh
6
7import (
8	"bytes"
9	"math/big"
10	"math/rand"
11	"reflect"
12	"testing"
13	"testing/quick"
14)
15
// intLengthTests pairs an integer with the encoded length intLength is
// expected to report for it. Each expected length is written as 4 + n:
// presumably a 4-byte length prefix followed by n bytes of two's-complement
// payload (RFC 4251 mpint). Note that 128 needs a second byte while -1 fits
// in one.
var intLengthTests = []struct {
	val, length int
}{
	{val: 0, length: 4 + 0},
	{val: 1, length: 4 + 1},
	{val: 127, length: 4 + 1},
	{val: 128, length: 4 + 2},
	{val: -1, length: 4 + 1},
}
25
26func TestIntLength(t *testing.T) {
27	for _, test := range intLengthTests {
28		v := new(big.Int).SetInt64(int64(test.val))
29		length := intLength(v)
30		if length != test.length {
31			t.Errorf("For %d, got length %d but expected %d", test.val, length, test.length)
32		}
33	}
34}
35
// msgAllTypes is a test message with one field of each kind that
// TestMarshalUnmarshal round-trips through Marshal and Unmarshal.
// NOTE(review): field order and struct tags presumably determine the wire
// layout, so fields should not be reordered or retagged casually.
type msgAllTypes struct {
	// The sshtype tag on the first field presumably supplies the message
	// number (21) for this struct — confirm against Marshal/Unmarshal.
	Bool    bool `sshtype:"21"`
	Array   [16]byte
	Uint64  uint64
	Uint32  uint32
	Uint8   uint8
	String  string
	Strings []string
	Bytes   []byte
	Int     *big.Int
	// Tagged ssh:"rest"; presumably absorbs all bytes left after the
	// preceding fields — confirm against Unmarshal.
	Rest    []byte `ssh:"rest"`
}
48
49func (t *msgAllTypes) Generate(rand *rand.Rand, size int) reflect.Value {
50	m := &msgAllTypes{}
51	m.Bool = rand.Intn(2) == 1
52	randomBytes(m.Array[:], rand)
53	m.Uint64 = uint64(rand.Int63n(1<<63 - 1))
54	m.Uint32 = uint32(rand.Intn((1 << 31) - 1))
55	m.Uint8 = uint8(rand.Intn(1 << 8))
56	m.String = string(m.Array[:])
57	m.Strings = randomNameList(rand)
58	m.Bytes = m.Array[:]
59	m.Int = randomInt(rand)
60	m.Rest = m.Array[:]
61	return reflect.ValueOf(m)
62}
63
64func TestMarshalUnmarshal(t *testing.T) {
65	rand := rand.New(rand.NewSource(0))
66	iface := &msgAllTypes{}
67	ty := reflect.ValueOf(iface).Type()
68
69	n := 100
70	if testing.Short() {
71		n = 5
72	}
73	for j := 0; j < n; j++ {
74		v, ok := quick.Value(ty, rand)
75		if !ok {
76			t.Errorf("failed to create value")
77			break
78		}
79
80		m1 := v.Elem().Interface()
81		m2 := iface
82
83		marshaled := Marshal(m1)
84		if err := Unmarshal(marshaled, m2); err != nil {
85			t.Errorf("Unmarshal %#v: %s", m1, err)
86			break
87		}
88
89		if !reflect.DeepEqual(v.Interface(), m2) {
90			t.Errorf("got: %#v\nwant:%#v\n%x", m2, m1, marshaled)
91			break
92		}
93	}
94}
95
96func TestUnmarshalEmptyPacket(t *testing.T) {
97	var b []byte
98	var m channelRequestSuccessMsg
99	if err := Unmarshal(b, &m); err == nil {
100		t.Fatalf("unmarshal of empty slice succeeded")
101	}
102}
103
104func TestUnmarshalUnexpectedPacket(t *testing.T) {
105	type S struct {
106		I uint32 `sshtype:"43"`
107		S string
108		B bool
109	}
110
111	s := S{11, "hello", true}
112	packet := Marshal(s)
113	packet[0] = 42
114	roundtrip := S{}
115	err := Unmarshal(packet, &roundtrip)
116	if err == nil {
117		t.Fatal("expected error, not nil")
118	}
119}
120
121func TestMarshalPtr(t *testing.T) {
122	s := struct {
123		S string
124	}{"hello"}
125
126	m1 := Marshal(s)
127	m2 := Marshal(&s)
128	if !bytes.Equal(m1, m2) {
129		t.Errorf("got %q, want %q for marshaled pointer", m2, m1)
130	}
131}
132
133func TestBareMarshalUnmarshal(t *testing.T) {
134	type S struct {
135		I uint32
136		S string
137		B bool
138	}
139
140	s := S{42, "hello", true}
141	packet := Marshal(s)
142	roundtrip := S{}
143	Unmarshal(packet, &roundtrip)
144
145	if !reflect.DeepEqual(s, roundtrip) {
146		t.Errorf("got %#v, want %#v", roundtrip, s)
147	}
148}
149
150func TestBareMarshal(t *testing.T) {
151	type S2 struct {
152		I uint32
153	}
154	s := S2{42}
155	packet := Marshal(s)
156	i, rest, ok := parseUint32(packet)
157	if len(rest) > 0 || !ok {
158		t.Errorf("parseInt(%q): parse error", packet)
159	}
160	if i != s.I {
161		t.Errorf("got %d, want %d", i, s.I)
162	}
163}
164
165func TestUnmarshalShortKexInitPacket(t *testing.T) {
166	// This used to panic.
167	// Issue 11348
168	packet := []byte{0x14, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0xff, 0xff, 0xff, 0xff}
169	kim := &kexInitMsg{}
170	if err := Unmarshal(packet, kim); err == nil {
171		t.Error("truncated packet unmarshaled without error")
172	}
173}
174
175func TestMarshalMultiTag(t *testing.T) {
176	var res struct {
177		A uint32 `sshtype:"1|2"`
178	}
179
180	good1 := struct {
181		A uint32 `sshtype:"1"`
182	}{
183		1,
184	}
185	good2 := struct {
186		A uint32 `sshtype:"2"`
187	}{
188		1,
189	}
190
191	if e := Unmarshal(Marshal(good1), &res); e != nil {
192		t.Errorf("error unmarshaling multipart tag: %v", e)
193	}
194
195	if e := Unmarshal(Marshal(good2), &res); e != nil {
196		t.Errorf("error unmarshaling multipart tag: %v", e)
197	}
198
199	bad1 := struct {
200		A uint32 `sshtype:"3"`
201	}{
202		1,
203	}
204	if e := Unmarshal(Marshal(bad1), &res); e == nil {
205		t.Errorf("bad struct unmarshaled without error")
206	}
207}
208
// randomBytes fills out with pseudo-random bytes from rand. Each byte is the
// low 8 bits of one Int31 call, consumed in index order.
func randomBytes(out []byte, rand *rand.Rand) {
	for idx := range out {
		out[idx] = byte(rand.Int31())
	}
}
214
// randomNameList returns between 0 and 15 random names. Each name is 1-16
// bytes long, drawn from the letters 'a' through 'p'. The result is never
// nil, even when it is empty.
func randomNameList(rand *rand.Rand) []string {
	count := int(rand.Int31() & 15)
	names := make([]string, 0, count)
	for len(names) < count {
		name := make([]byte, 1+(rand.Int31()&15))
		for k := range name {
			name[k] = 'a' + byte(rand.Int31()&15)
		}
		names = append(names, string(name))
	}
	return names
}
226
// randomInt returns a big.Int holding a random value in the signed 32-bit
// range, obtained by reinterpreting one random uint32 as an int32.
func randomInt(rand *rand.Rand) *big.Int {
	v := int32(rand.Uint32())
	return big.NewInt(int64(v))
}
230
231func (*kexInitMsg) Generate(rand *rand.Rand, size int) reflect.Value {
232	ki := &kexInitMsg{}
233	randomBytes(ki.Cookie[:], rand)
234	ki.KexAlgos = randomNameList(rand)
235	ki.ServerHostKeyAlgos = randomNameList(rand)
236	ki.CiphersClientServer = randomNameList(rand)
237	ki.CiphersServerClient = randomNameList(rand)
238	ki.MACsClientServer = randomNameList(rand)
239	ki.MACsServerClient = randomNameList(rand)
240	ki.CompressionClientServer = randomNameList(rand)
241	ki.CompressionServerClient = randomNameList(rand)
242	ki.LanguagesClientServer = randomNameList(rand)
243	ki.LanguagesServerClient = randomNameList(rand)
244	if rand.Int31()&1 == 1 {
245		ki.FirstKexFollows = true
246	}
247	return reflect.ValueOf(ki)
248}
249
250func (*kexDHInitMsg) Generate(rand *rand.Rand, size int) reflect.Value {
251	dhi := &kexDHInitMsg{}
252	dhi.X = randomInt(rand)
253	return reflect.ValueOf(dhi)
254}
255
256var (
257	_kexInitMsg   = new(kexInitMsg).Generate(rand.New(rand.NewSource(0)), 10).Elem().Interface()
258	_kexDHInitMsg = new(kexDHInitMsg).Generate(rand.New(rand.NewSource(0)), 10).Elem().Interface()
259
260	_kexInit   = Marshal(_kexInitMsg)
261	_kexDHInit = Marshal(_kexDHInitMsg)
262)
263
264func BenchmarkMarshalKexInitMsg(b *testing.B) {
265	for i := 0; i < b.N; i++ {
266		Marshal(_kexInitMsg)
267	}
268}
269
270func BenchmarkUnmarshalKexInitMsg(b *testing.B) {
271	m := new(kexInitMsg)
272	for i := 0; i < b.N; i++ {
273		Unmarshal(_kexInit, m)
274	}
275}
276
277func BenchmarkMarshalKexDHInitMsg(b *testing.B) {
278	for i := 0; i < b.N; i++ {
279		Marshal(_kexDHInitMsg)
280	}
281}
282
283func BenchmarkUnmarshalKexDHInitMsg(b *testing.B) {
284	m := new(kexDHInitMsg)
285	for i := 0; i < b.N; i++ {
286		Unmarshal(_kexDHInit, m)
287	}
288}
289