1# Copyright (c) 2013-2018 by Ron Frederick <ronf@timeheart.net> and others.
2#
3# This program and the accompanying materials are made available under
4# the terms of the Eclipse Public License v2.0 which accompanies this
5# distribution and is available at:
6#
7#     http://www.eclipse.org/legal/epl-2.0/
8#
9# This program may also be made available under the following secondary
10# licenses when the conditions for such availability set forth in the
11# Eclipse Public License v2.0 are satisfied:
12#
13#    GNU General Public License, Version 2.0, or any later versions of
14#    that license
15#
16# SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later
17#
18# Contributors:
19#     Ron Frederick - initial implementation, API, and documentation
20
21"""SSH packet encoding and decoding functions"""
22
23from .misc import plural
24
25
class PacketDecodeError(ValueError):
    """Raised when the contents of an SSH packet cannot be decoded"""
28
29
def Byte(value):
    """Encode an integer in the range 0-255 as a single byte"""

    return bytes([value])
34
35
def Boolean(value):
    """Encode the truth value of an object as a single byte (0 or 1)"""

    flag = True if value else False
    return Byte(flag)
40
41
def UInt32(value):
    """Encode an unsigned integer as 4 bytes in big-endian order"""

    return value.to_bytes(length=4, byteorder='big')
46
47
def UInt64(value):
    """Encode an unsigned integer as 8 bytes in big-endian order"""

    return value.to_bytes(length=8, byteorder='big')
52
53
def String(value):
    """Encode a byte string or text string with a 4-byte length prefix

       Text (str) values are converted to UTF-8 before being encoded.
       The prefix is the big-endian byte count of the encoded data.

    """

    data = value.encode('utf-8') if isinstance(value, str) else value
    prefix = len(data).to_bytes(4, 'big')

    return prefix + data
61
62
def MPInt(value):
    """Encode a multiple precision integer value

       The result is a 4-byte big-endian length followed by the value
       as a signed (two's complement) big-endian integer using the
       fewest bytes which can represent it. Zero is encoded with no
       data bytes at all.

    """

    magnitude_bits = value.bit_length()

    if value == 0:
        size = 0
    elif value > 0:
        # One bit beyond the magnitude is needed to hold a 0 sign bit
        size = magnitude_bits // 8 + 1
    else:
        # A negative value needs the bits of its complement plus a
        # sign bit, so -128 fits in one byte while -129 needs two
        size = (~value).bit_length() // 8 + 1

    return size.to_bytes(4, 'big') + value.to_bytes(size, 'big', signed=True)
71
72
def NameList(value):
    """Encode an iterable of byte strings as a comma-separated list"""

    joined = b','.join(value)
    return String(joined)
77
78
class SSHPacket:
    """Sequential decoder for the contents of an SSH packet

       The packet data is held along with a read offset. Each get_*
       method decodes one value starting at that offset and then
       advances the offset past the data it consumed.

    """

    def __init__(self, packet):
        self._packet = packet
        self._idx = 0
        self._len = len(packet)

    def __bool__(self):
        # A packet is truthy while it still has unconsumed data
        remaining = self._len - self._idx
        return remaining != 0

    def check_end(self):
        """Raise an error if any data in the packet is still unconsumed"""

        if not self:
            return

        raise PacketDecodeError('Unexpected data at end of packet')

    def get_consumed_payload(self):
        """Return the data from the start of the packet up to the offset"""

        consumed = self._idx
        return self._packet[:consumed]

    def get_remaining_payload(self):
        """Return the data from the current offset to the end of the packet"""

        consumed = self._idx
        return self._packet[consumed:]

    def get_full_payload(self):
        """Return the entire packet, regardless of the current offset"""

        return self._packet

    def get_bytes(self, size):
        """Consume and return the next size bytes of the packet

           Raises PacketDecodeError if reading size bytes would run
           past the end of the packet.

        """

        start = self._idx
        stop = start + size

        if stop > self._len:
            raise PacketDecodeError('Incomplete packet')

        chunk = self._packet[start:stop]
        self._idx = stop

        return chunk

    def get_byte(self):
        """Consume one byte and return its value as an integer"""

        chunk = self.get_bytes(1)
        return chunk[0]

    def get_boolean(self):
        """Consume one byte and return True if it is non-zero"""

        return self.get_byte() != 0

    def get_uint32(self):
        """Consume 4 bytes and return them as a big-endian unsigned integer"""

        return int.from_bytes(self.get_bytes(4), byteorder='big')

    def get_uint64(self):
        """Consume 8 bytes and return them as a big-endian unsigned integer"""

        return int.from_bytes(self.get_bytes(8), byteorder='big')

    def get_string(self):
        """Consume a length-prefixed string and return its raw bytes

           The length is read as a 32-bit integer, after which that
           many bytes are returned as is, without any decoding.

        """

        length = self.get_uint32()
        return self.get_bytes(length)

    def get_mpint(self):
        """Consume a length-prefixed signed big-endian integer"""

        data = self.get_string()
        return int.from_bytes(data, byteorder='big', signed=True)

    def get_namelist(self):
        """Consume a length-prefixed, comma-separated list of byte strings

           An empty string decodes to an empty list.

        """

        names = self.get_string()

        if not names:
            return []

        return names.split(b',')
156
157
class SSHPacketLogger:
    """Base class providing logging of sent and received SSH packets"""

    # Names to show in log messages, keyed by packet type. Empty here;
    # presumably filled in by subclasses.
    _handler_names = {}

    @property
    def logger(self):
        """The logger to use for packet logging

           This base implementation always raises NotImplementedError.

        """

        raise NotImplementedError

    def _log_packet(self, msg, pkttype, pktid, packet, note):
        """Log a packet, labeled with msg, its type, and its size

           The packet may be either an SSHPacket or its raw payload.
           A non-empty note is appended to the message in parentheses.

        """

        if isinstance(packet, SSHPacket):
            payload = packet.get_full_payload()
        else:
            payload = packet

        # Use the symbolic name when one is known for this packet type
        try:
            type_name = self._handler_names[pkttype]
        except KeyError:
            description = 'packet type %d' % pkttype
        else:
            description = '%s (%d)' % (type_name, pkttype)

        size = plural(len(payload), 'byte')
        suffix = ' (%s)' % note if note else note

        self.logger.packet(pktid, payload, '%s %s, %s%s',
                           msg, description, size, suffix)

    def log_sent_packet(self, pkttype, pktid, packet, note=''):
        """Log a packet which was sent"""

        self._log_packet('Sent', pkttype, pktid, packet, note)

    def log_received_packet(self, pkttype, pktid, packet, note=''):
        """Log a packet which was received"""

        self._log_packet('Received', pkttype, pktid, packet, note)
198
199
class SSHPacketHandler(SSHPacketLogger):
    """Base class for objects which dispatch received SSH packets"""

    # Handler callables keyed by packet type. Empty here; presumably
    # filled in by subclasses.
    _packet_handlers = {}

    @property
    def logger(self):
        """The logger associated with this packet handler

           This base implementation always raises NotImplementedError.

        """

        raise NotImplementedError

    def process_packet(self, pkttype, pktid, packet):
        """Pass a received packet to the handler for its packet type

           The handler is called with this object, the packet type,
           the packet ID, and the packet. Returns True if a handler
           for the packet type was found and called, or False if no
           handler is registered for it.

        """

        if pkttype not in self._packet_handlers:
            return False

        handler = self._packet_handlers[pkttype]
        handler(self, pkttype, pktid, packet)

        return True
219