1# Copyright (C) 2011  Jeff Forcier <jeff@bitprophet.org>
2#
3# This file is part of ssh.
4#
5# 'ssh' is free software; you can redistribute it and/or modify it under the
6# terms of the GNU Lesser General Public License as published by the Free
7# Software Foundation; either version 2.1 of the License, or (at your option)
8# any later version.
9#
# 'ssh' is distributed in the hope that it will be useful, but WITHOUT ANY
11# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
12# A PARTICULAR PURPOSE.  See the GNU Lesser General Public License for more
13# details.
14#
15# You should have received a copy of the GNU Lesser General Public License
16# along with 'ssh'; if not, write to the Free Software Foundation, Inc.,
17# 51 Franklin Street, Suite 500, Boston, MA  02110-1335  USA.
18
19"""
20Variant on L{KexGroup1 <ssh.kex_group1.KexGroup1>} where the prime "p" and
21generator "g" are provided by the server.  A bit more work is required on the
22client side, and a B{lot} more on the server side.
23"""
24
25from Crypto.Hash import SHA
26from Crypto.Util import number
27
28from ssh.common import *
29from ssh import util
30from ssh.message import Message
31from ssh.ssh_exception import SSHException
32
33
# Message numbers used by the group-exchange kex, assigned consecutively from
# 30 (REQUEST_OLD) through 34 (REQUEST).
(_MSG_KEXDH_GEX_REQUEST_OLD,
 _MSG_KEXDH_GEX_GROUP,
 _MSG_KEXDH_GEX_INIT,
 _MSG_KEXDH_GEX_REPLY,
 _MSG_KEXDH_GEX_REQUEST) = range(30, 35)
36
37
class KexGex (object):
    """
    Client and server halves of the C{diffie-hellman-group-exchange-sha1}
    key exchange, driven one packet at a time through L{parse_next}.
    """

    name = 'diffie-hellman-group-exchange-sha1'
    # modulus size range in bits: as a client this is what we request; as a
    # server it bounds the preferred size we will honour from a client
    min_bits = 1024
    max_bits = 8192
    preferred_bits = 2048

    def __init__(self, transport):
        """
        @param transport: the transport this exchange runs on; it supplies the
            rng, version and kexinit strings, host key and packet I/O.
        """
        self.transport = transport
        self.p = None    # group prime
        self.q = None    # NOTE(review): never assigned in this class; kept for compatibility
        self.g = None    # group generator
        self.x = None    # our private exponent
        self.e = None    # client's public value
        self.f = None    # server's public value
        self.old_style = False    # True when the min/max-less request format is in use

    def start_kex(self, _test_old_style=False):
        """
        Begin the exchange.  In server mode, just wait for the client's
        request; in client mode, send our group-size request and wait for the
        server's group.

        @param _test_old_style: send the old (preferred-size-only) request
            instead of the min/preferred/max one; only used by unit tests.
        """
        if self.transport.server_mode:
            self.transport._expect_packet(_MSG_KEXDH_GEX_REQUEST, _MSG_KEXDH_GEX_REQUEST_OLD)
            return
        # request a bit range: we accept (min_bits) to (max_bits), but prefer
        # (preferred_bits).  according to the spec, we shouldn't pull the
        # minimum up above 1024.
        m = Message()
        if _test_old_style:
            # only used for unit tests: we shouldn't ever send this
            m.add_byte(chr(_MSG_KEXDH_GEX_REQUEST_OLD))
            m.add_int(self.preferred_bits)
            self.old_style = True
        else:
            m.add_byte(chr(_MSG_KEXDH_GEX_REQUEST))
            m.add_int(self.min_bits)
            m.add_int(self.preferred_bits)
            m.add_int(self.max_bits)
        self.transport._send_message(m)
        self.transport._expect_packet(_MSG_KEXDH_GEX_GROUP)

    def parse_next(self, ptype, m):
        """
        Dispatch an incoming kex packet to the matching handler.

        @raise SSHException: if C{ptype} is not a group-exchange message.
        """
        if ptype == _MSG_KEXDH_GEX_REQUEST:
            return self._parse_kexdh_gex_request(m)
        elif ptype == _MSG_KEXDH_GEX_GROUP:
            return self._parse_kexdh_gex_group(m)
        elif ptype == _MSG_KEXDH_GEX_INIT:
            return self._parse_kexdh_gex_init(m)
        elif ptype == _MSG_KEXDH_GEX_REPLY:
            return self._parse_kexdh_gex_reply(m)
        elif ptype == _MSG_KEXDH_GEX_REQUEST_OLD:
            return self._parse_kexdh_gex_request_old(m)
        raise SSHException('KexGex asked to handle packet type %d' % ptype)


    ###  internals...


    def _generate_x(self):
        # generate an "x" (1 < x < (p-1)/2).
        q = (self.p - 1) // 2
        qnorm = util.deflate_long(q, 0)
        qhbyte = ord(qnorm[0])
        nbytes = len(qnorm)
        # build a mask for the leading byte that clears every bit above q's
        # highest set bit, so candidates are never longer (in bits) than q
        qmask = 0xff
        while not (qhbyte & 0x80):
            qhbyte <<= 1
            qmask >>= 1
        # rejection-sample until the candidate lands strictly inside (1, q)
        while True:
            x_bytes = self.transport.rng.read(nbytes)
            x_bytes = chr(ord(x_bytes[0]) & qmask) + x_bytes[1:]
            x = util.inflate_long(x_bytes, 1)
            if (x > 1) and (x < q):
                break
        self.x = x

    def _clamp_preferred(self, bits):
        """Return C{bits} forced into our C{[min_bits, max_bits]} range."""
        if bits > self.max_bits:
            bits = self.max_bits
        if bits < self.min_bits:
            bits = self.min_bits
        return bits

    def _require_modulus_pack(self):
        """
        Return the transport's modulus pack.

        @raise SSHException: if the transport has none (server-side gex
            cannot pick a group without one).
        """
        pack = self.transport._get_modulus_pack()
        if pack is None:
            raise SSHException('Can\'t do server-side gex with no modulus pack')
        return pack

    def _send_group(self):
        """Send the chosen (p, g) to the client and wait for its "e"."""
        m = Message()
        m.add_byte(chr(_MSG_KEXDH_GEX_GROUP))
        m.add_mpint(self.p)
        m.add_mpint(self.g)
        self.transport._send_message(m)
        self.transport._expect_packet(_MSG_KEXDH_GEX_INIT)

    def _exchange_hash(self, v_c, v_s, i_c, i_s, host_key, K):
        """
        Compute the SHA-1 exchange hash H of
        (V_C || V_S || I_C || I_S || K_S || min || n || max || p || g || e || f || K).
        In old-style exchanges min and max are omitted and only n is hashed.
        """
        hm = Message()
        hm.add(v_c, v_s, i_c, i_s, host_key)
        if self.old_style:
            hm.add_int(self.preferred_bits)
        else:
            hm.add_int(self.min_bits)
            hm.add_int(self.preferred_bits)
            hm.add_int(self.max_bits)
        hm.add_mpint(self.p)
        hm.add_mpint(self.g)
        hm.add_mpint(self.e)
        hm.add_mpint(self.f)
        hm.add_mpint(K)
        return SHA.new(str(hm)).digest()

    def _parse_kexdh_gex_request(self, m):
        # server side: the client sent min || n || max
        req_min = m.get_int()
        req_preferred = m.get_int()
        req_max = m.get_int()
        # smoosh the client's preferred size into our own limits.  this must
        # run before the request is saved below, since that overwrites them.
        preferredbits = self._clamp_preferred(req_preferred)
        # fix min/max if they're inconsistent.  technically, we could just pout
        # and hang up, but there's no harm in giving them the benefit of the
        # doubt and just picking a bitsize for them.
        minbits = min(req_min, preferredbits)
        maxbits = max(req_max, preferredbits)
        # save the client's values exactly as received: the exchange hash
        # covers the min || n || max the client sent (the client hashes its
        # own request), so hashing our adjusted values would make H differ on
        # the two sides whenever we had to adjust anything.
        self.min_bits = req_min
        self.preferred_bits = req_preferred
        self.max_bits = req_max
        # generate prime
        pack = self._require_modulus_pack()
        self.transport._log(DEBUG, 'Picking p (%d <= %d <= %d bits)' % (minbits, preferredbits, maxbits))
        self.g, self.p = pack.get_modulus(minbits, preferredbits, maxbits)
        self._send_group()

    def _parse_kexdh_gex_request_old(self, m):
        # same as above, but without min_bits or max_bits (used by older clients like putty)
        req_preferred = m.get_int()
        # smoosh the client's preferred size into our own limits
        preferredbits = self._clamp_preferred(req_preferred)
        # hash the client's value as sent (see _parse_kexdh_gex_request)
        self.preferred_bits = req_preferred
        # generate prime
        pack = self._require_modulus_pack()
        self.transport._log(DEBUG, 'Picking p (~ %d bits)' % (preferredbits,))
        self.g, self.p = pack.get_modulus(self.min_bits, preferredbits, self.max_bits)
        self._send_group()
        self.old_style = True

    def _parse_kexdh_gex_group(self, m):
        # client side: the server chose (p, g)
        self.p = m.get_mpint()
        self.g = m.get_mpint()
        # reject if p's bit length < 1024 or > 8192
        bitlen = util.bit_length(self.p)
        if (bitlen < 1024) or (bitlen > 8192):
            raise SSHException('Server-generated gex p (don\'t ask) is out of range (%d bits)' % bitlen)
        self.transport._log(DEBUG, 'Got server p (%d bits)' % bitlen)
        self._generate_x()
        # now compute e = g^x mod p
        self.e = pow(self.g, self.x, self.p)
        m = Message()
        m.add_byte(chr(_MSG_KEXDH_GEX_INIT))
        m.add_mpint(self.e)
        self.transport._send_message(m)
        self.transport._expect_packet(_MSG_KEXDH_GEX_REPLY)

    def _parse_kexdh_gex_init(self, m):
        # server side: the client sent its public value "e"
        self.e = m.get_mpint()
        if (self.e < 1) or (self.e > self.p - 1):
            raise SSHException('Client kex "e" is out of range')
        self._generate_x()
        self.f = pow(self.g, self.x, self.p)
        K = pow(self.e, self.x, self.p)
        server_key = self.transport.get_server_key()
        key = str(server_key)
        # from our side the client is "remote", so remote strings come first
        H = self._exchange_hash(self.transport.remote_version, self.transport.local_version,
                                self.transport.remote_kex_init, self.transport.local_kex_init,
                                key, K)
        self.transport._set_K_H(K, H)
        # sign it
        sig = server_key.sign_ssh_data(self.transport.rng, H)
        # send reply
        m = Message()
        m.add_byte(chr(_MSG_KEXDH_GEX_REPLY))
        m.add_string(key)
        m.add_mpint(self.f)
        m.add_string(str(sig))
        self.transport._send_message(m)
        self.transport._activate_outbound()

    def _parse_kexdh_gex_reply(self, m):
        # client side: the server sent its host key, "f", and a signature over H
        host_key = m.get_string()
        self.f = m.get_mpint()
        sig = m.get_string()
        if (self.f < 1) or (self.f > self.p - 1):
            raise SSHException('Server kex "f" is out of range')
        K = pow(self.f, self.x, self.p)
        H = self._exchange_hash(self.transport.local_version, self.transport.remote_version,
                                self.transport.local_kex_init, self.transport.remote_kex_init,
                                host_key, K)
        self.transport._set_K_H(K, H)
        self.transport._verify_key(host_key, sig)
        self.transport._activate_outbound()
244