# Copyright (c) 2013-2018 by Ron Frederick <ronf@timeheart.net> and others.
#
# This program and the accompanying materials are made available under
# the terms of the Eclipse Public License v2.0 which accompanies this
# distribution and is available at:
#
#     http://www.eclipse.org/legal/epl-2.0/
#
# This program may also be made available under the following secondary
# licenses when the conditions for such availability set forth in the
# Eclipse Public License v2.0 are satisfied:
#
#    GNU General Public License, Version 2.0, or any later versions of
#    that license
#
# SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later
#
# Contributors:
#     Ron Frederick - initial implementation, API, and documentation

"""SSH packet encoding and decoding functions"""

from .misc import plural


class PacketDecodeError(ValueError):
    """Packet decoding error"""


def Byte(value):
    """Encode a single byte

       The value must be an integer in the range 0-255.
    """

    return bytes((value,))


def Boolean(value):
    """Encode a boolean value as a single byte (1 for true, 0 for false)"""

    return Byte(bool(value))


def UInt32(value):
    """Encode a 32-bit unsigned integer value in big-endian byte order"""

    return value.to_bytes(4, 'big')


def UInt64(value):
    """Encode a 64-bit unsigned integer value in big-endian byte order"""

    return value.to_bytes(8, 'big')


def String(value):
    """Encode a byte string or UTF-8 string value

       The result is a 4-byte big-endian length followed by the bytes.
       A str argument is first encoded as UTF-8.
    """

    if isinstance(value, str):
        value = value.encode('utf-8', errors='strict')

    return len(value).to_bytes(4, 'big') + value


def MPInt(value):
    """Encode a multiple precision integer value

       The result is a 4-byte big-endian length followed by the
       minimal big-endian two's complement representation of value.
       Zero encodes as a zero length with no data bytes.
    """

    nbits = value.bit_length()

    # When the magnitude exactly fills a whole number of bytes, one more
    # bit is needed to hold the sign. The two exceptions are zero (no
    # bytes at all) and the most negative value of that width,
    # -2**(nbits-1), which already fits. The value != 0 test must come
    # before the shift so that nbits - 1 is never negative.
    if nbits % 8 == 0 and value != 0 and value != -1 << (nbits - 1):
        nbits += 1

    nbytes = (nbits + 7) // 8

    return nbytes.to_bytes(4, 'big') + value.to_bytes(nbytes, 'big',
                                                      signed=True)


def NameList(value):
    """Encode a comma-separated list of byte strings

       The items must be bytes; they are joined with b',' and encoded
       as a single length-prefixed string.
    """

    return String(b','.join(value))


class SSHPacket:
    """Decoder class for SSH packets

       Values are extracted sequentially from the start of the packet,
       advancing an internal read index as each value is consumed.
    """

    def __init__(self, packet):
        self._packet = packet
        self._idx = 0
        self._len = len(packet)

    def __bool__(self):
        """Return whether unconsumed data remains in the packet"""

        return self._idx != self._len

    def check_end(self):
        """Confirm that all of the data in the packet has been consumed

           Raises PacketDecodeError if any data remains.
        """

        if self:
            raise PacketDecodeError('Unexpected data at end of packet')

    def get_consumed_payload(self):
        """Return the portion of the packet consumed so far"""

        return self._packet[:self._idx]

    def get_remaining_payload(self):
        """Return the portion of the packet not yet consumed"""

        return self._packet[self._idx:]

    def get_full_payload(self):
        """Return the full packet"""

        return self._packet

    def get_bytes(self, size):
        """Extract the requested number of bytes from the packet

           Raises PacketDecodeError if size is negative or if fewer than
           size bytes remain in the packet.
        """

        # A negative size would otherwise move the read index backwards
        # and return a bogus slice instead of failing.
        if size < 0:
            raise PacketDecodeError('Invalid length: %d' % size)

        if self._idx + size > self._len:
            raise PacketDecodeError('Incomplete packet')

        value = self._packet[self._idx:self._idx+size]
        self._idx += size
        return value

    def get_byte(self):
        """Extract a single byte from the packet as an integer"""

        return self.get_bytes(1)[0]

    def get_boolean(self):
        """Extract a boolean from the packet (any non-zero byte is true)"""

        return bool(self.get_byte())

    def get_uint32(self):
        """Extract a 32-bit big-endian unsigned integer from the packet"""

        return int.from_bytes(self.get_bytes(4), 'big')

    def get_uint64(self):
        """Extract a 64-bit big-endian unsigned integer from the packet"""

        return int.from_bytes(self.get_bytes(8), 'big')

    def get_string(self):
        """Extract a length-prefixed string from the packet

           The string is returned as raw bytes. No decoding is done, so
           callers that expect text must decode it themselves.
        """

        return self.get_bytes(self.get_uint32())

    def get_mpint(self):
        """Extract a multiple precision integer from the packet

           The value is read as a length-prefixed string and interpreted
           as a big-endian two's complement integer. An empty string
           decodes as zero.
        """

        return int.from_bytes(self.get_string(), 'big', signed=True)

    def get_namelist(self):
        """Extract a comma-separated list of byte strings from the packet

           An empty string decodes as an empty list.
        """

        namelist = self.get_string()
        return namelist.split(b',') if namelist else []


class SSHPacketLogger:
    """Parent class for SSH packet loggers

       Subclasses provide the logger property and may populate
       _handler_names to map packet type numbers to display names.
    """

    # Maps packet type numbers to names used in log messages
    _handler_names = {}

    @property
    def logger(self):
        """The logger to use for packet logging"""

        raise NotImplementedError

    def _log_packet(self, msg, pkttype, pktid, packet, note):
        """Log a sent/received packet

           packet may be an SSHPacket or the raw packet data. note, if
           non-empty, is appended in parentheses to the log message.
        """

        if isinstance(packet, SSHPacket):
            packet = packet.get_full_payload()

        try:
            name = '%s (%d)' % (self._handler_names[pkttype], pkttype)
        except KeyError:
            name = 'packet type %d' % pkttype

        count = plural(len(packet), 'byte')

        if note:
            note = ' (%s)' % note

        self.logger.packet(pktid, packet, '%s %s, %s%s',
                           msg, name, count, note)

    def log_sent_packet(self, pkttype, pktid, packet, note=''):
        """Log a sent packet"""

        self._log_packet('Sent', pkttype, pktid, packet, note)

    def log_received_packet(self, pkttype, pktid, packet, note=''):
        """Log a received packet"""

        self._log_packet('Received', pkttype, pktid, packet, note)


class SSHPacketHandler(SSHPacketLogger):
    """Parent class for SSH packet handlers

       Subclasses populate _packet_handlers, mapping packet type numbers
       to functions called as handler(self, pkttype, pktid, packet).
    """

    # Maps packet type numbers to handler functions
    _packet_handlers = {}

    @property
    def logger(self):
        """The logger associated with this packet handler"""

        raise NotImplementedError

    def process_packet(self, pkttype, pktid, packet):
        """Dispatch a received packet to its registered handler

           Returns True if a handler for pkttype was found and called,
           or False if no handler is registered for this packet type.
        """

        # Only the lookup is guarded, so a KeyError raised inside a
        # handler is not mistaken for an unknown packet type.
        try:
            handler = self._packet_handlers[pkttype]
        except KeyError:
            return False

        # Handlers are stored as plain functions, so self is passed
        # explicitly.
        handler(self, pkttype, pktid, packet)
        return True