1// Copyright 2019 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package ssh
6
7import (
8	"reflect"
9	"testing"
10)
11
12func TestFindAgreedAlgorithms(t *testing.T) {
13	initKex := func(k *kexInitMsg) {
14		if k.KexAlgos == nil {
15			k.KexAlgos = []string{"kex1"}
16		}
17		if k.ServerHostKeyAlgos == nil {
18			k.ServerHostKeyAlgos = []string{"hostkey1"}
19		}
20		if k.CiphersClientServer == nil {
21			k.CiphersClientServer = []string{"cipher1"}
22
23		}
24		if k.CiphersServerClient == nil {
25			k.CiphersServerClient = []string{"cipher1"}
26
27		}
28		if k.MACsClientServer == nil {
29			k.MACsClientServer = []string{"mac1"}
30
31		}
32		if k.MACsServerClient == nil {
33			k.MACsServerClient = []string{"mac1"}
34
35		}
36		if k.CompressionClientServer == nil {
37			k.CompressionClientServer = []string{"compression1"}
38
39		}
40		if k.CompressionServerClient == nil {
41			k.CompressionServerClient = []string{"compression1"}
42
43		}
44		if k.LanguagesClientServer == nil {
45			k.LanguagesClientServer = []string{"language1"}
46
47		}
48		if k.LanguagesServerClient == nil {
49			k.LanguagesServerClient = []string{"language1"}
50
51		}
52	}
53
54	initDirAlgs := func(a *directionAlgorithms) {
55		if a.Cipher == "" {
56			a.Cipher = "cipher1"
57		}
58		if a.MAC == "" {
59			a.MAC = "mac1"
60		}
61		if a.Compression == "" {
62			a.Compression = "compression1"
63		}
64	}
65
66	initAlgs := func(a *algorithms) {
67		if a.kex == "" {
68			a.kex = "kex1"
69		}
70		if a.hostKey == "" {
71			a.hostKey = "hostkey1"
72		}
73		initDirAlgs(&a.r)
74		initDirAlgs(&a.w)
75	}
76
77	type testcase struct {
78		name                   string
79		clientIn, serverIn     kexInitMsg
80		wantClient, wantServer algorithms
81		wantErr                bool
82	}
83
84	cases := []testcase{
85		testcase{
86			name: "standard",
87		},
88
89		testcase{
90			name: "no common hostkey",
91			serverIn: kexInitMsg{
92				ServerHostKeyAlgos: []string{"hostkey2"},
93			},
94			wantErr: true,
95		},
96
97		testcase{
98			name: "no common kex",
99			serverIn: kexInitMsg{
100				KexAlgos: []string{"kex2"},
101			},
102			wantErr: true,
103		},
104
105		testcase{
106			name: "no common cipher",
107			serverIn: kexInitMsg{
108				CiphersClientServer: []string{"cipher2"},
109			},
110			wantErr: true,
111		},
112
113		testcase{
114			name: "client decides cipher",
115			serverIn: kexInitMsg{
116				CiphersClientServer: []string{"cipher1", "cipher2"},
117				CiphersServerClient: []string{"cipher2", "cipher3"},
118			},
119			clientIn: kexInitMsg{
120				CiphersClientServer: []string{"cipher2", "cipher1"},
121				CiphersServerClient: []string{"cipher3", "cipher2"},
122			},
123			wantClient: algorithms{
124				r: directionAlgorithms{
125					Cipher: "cipher3",
126				},
127				w: directionAlgorithms{
128					Cipher: "cipher2",
129				},
130			},
131			wantServer: algorithms{
132				w: directionAlgorithms{
133					Cipher: "cipher3",
134				},
135				r: directionAlgorithms{
136					Cipher: "cipher2",
137				},
138			},
139		},
140
141		// TODO(hanwen): fix and add tests for AEAD ignoring
142		// the MACs field
143	}
144
145	for i := range cases {
146		initKex(&cases[i].clientIn)
147		initKex(&cases[i].serverIn)
148		initAlgs(&cases[i].wantClient)
149		initAlgs(&cases[i].wantServer)
150	}
151
152	for _, c := range cases {
153		t.Run(c.name, func(t *testing.T) {
154			serverAlgs, serverErr := findAgreedAlgorithms(false, &c.clientIn, &c.serverIn)
155			clientAlgs, clientErr := findAgreedAlgorithms(true, &c.clientIn, &c.serverIn)
156
157			serverHasErr := serverErr != nil
158			clientHasErr := clientErr != nil
159			if c.wantErr != serverHasErr || c.wantErr != clientHasErr {
160				t.Fatalf("got client/server error (%v, %v), want hasError %v",
161					clientErr, serverErr, c.wantErr)
162
163			}
164			if c.wantErr {
165				return
166			}
167
168			if !reflect.DeepEqual(serverAlgs, &c.wantServer) {
169				t.Errorf("server: got algs %#v, want %#v", serverAlgs, &c.wantServer)
170			}
171			if !reflect.DeepEqual(clientAlgs, &c.wantClient) {
172				t.Errorf("server: got algs %#v, want %#v", clientAlgs, &c.wantClient)
173			}
174		})
175	}
176}
177