1# Copyright (C) 2011  Jeff Forcier <jeff@bitprophet.org>
2#
3# This file is part of ssh.
4#
5# 'ssh' is free software; you can redistribute it and/or modify it under the
6# terms of the GNU Lesser General Public License as published by the Free
7# Software Foundation; either version 2.1 of the License, or (at your option)
8# any later version.
9#
# 'ssh' is distributed in the hope that it will be useful, but WITHOUT ANY
11# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
12# A PARTICULAR PURPOSE.  See the GNU Lesser General Public License for more
13# details.
14#
15# You should have received a copy of the GNU Lesser General Public License
16# along with 'ssh'; if not, write to the Free Software Foundation, Inc.,
17# 51 Franklin Street, Suite 500, Boston, MA  02110-1335  USA.
18
19"""
20Implementation of an SSH2 "message".
21"""
22
23import struct
24import cStringIO
25
26from ssh import util
27
28
29class Message (object):
30    """
31    An SSH2 I{Message} is a stream of bytes that encodes some combination of
32    strings, integers, bools, and infinite-precision integers (known in python
33    as I{long}s).  This class builds or breaks down such a byte stream.
34
35    Normally you don't need to deal with anything this low-level, but it's
36    exposed for people implementing custom extensions, or features that
37    ssh doesn't support yet.
38    """
39
40    def __init__(self, content=None):
41        """
42        Create a new SSH2 Message.
43
44        @param content: the byte stream to use as the Message content (passed
45            in only when decomposing a Message).
46        @type content: string
47        """
48        if content != None:
49            self.packet = cStringIO.StringIO(content)
50        else:
51            self.packet = cStringIO.StringIO()
52
53    def __str__(self):
54        """
55        Return the byte stream content of this Message, as a string.
56
57        @return: the contents of this Message.
58        @rtype: string
59        """
60        return self.packet.getvalue()
61
62    def __repr__(self):
63        """
64        Returns a string representation of this object, for debugging.
65
66        @rtype: string
67        """
68        return 'ssh.Message(' + repr(self.packet.getvalue()) + ')'
69
70    def rewind(self):
71        """
72        Rewind the message to the beginning as if no items had been parsed
73        out of it yet.
74        """
75        self.packet.seek(0)
76
77    def get_remainder(self):
78        """
79        Return the bytes of this Message that haven't already been parsed and
80        returned.
81
82        @return: a string of the bytes not parsed yet.
83        @rtype: string
84        """
85        position = self.packet.tell()
86        remainder = self.packet.read()
87        self.packet.seek(position)
88        return remainder
89
90    def get_so_far(self):
91        """
92        Returns the bytes of this Message that have been parsed and returned.
93        The string passed into a Message's constructor can be regenerated by
94        concatenating C{get_so_far} and L{get_remainder}.
95
96        @return: a string of the bytes parsed so far.
97        @rtype: string
98        """
99        position = self.packet.tell()
100        self.rewind()
101        return self.packet.read(position)
102
103    def get_bytes(self, n):
104        """
105        Return the next C{n} bytes of the Message, without decomposing into
106        an int, string, etc.  Just the raw bytes are returned.
107
108        @return: a string of the next C{n} bytes of the Message, or a string
109            of C{n} zero bytes, if there aren't C{n} bytes remaining.
110        @rtype: string
111        """
112        b = self.packet.read(n)
113        if len(b) < n:
114            return b + '\x00' * (n - len(b))
115        return b
116
117    def get_byte(self):
118        """
119        Return the next byte of the Message, without decomposing it.  This
120        is equivalent to L{get_bytes(1)<get_bytes>}.
121
122        @return: the next byte of the Message, or C{'\000'} if there aren't
123            any bytes remaining.
124        @rtype: string
125        """
126        return self.get_bytes(1)
127
128    def get_boolean(self):
129        """
130        Fetch a boolean from the stream.
131
132        @return: C{True} or C{False} (from the Message).
133        @rtype: bool
134        """
135        b = self.get_bytes(1)
136        return b != '\x00'
137
138    def get_int(self):
139        """
140        Fetch an int from the stream.
141
142        @return: a 32-bit unsigned integer.
143        @rtype: int
144        """
145        return struct.unpack('>I', self.get_bytes(4))[0]
146
147    def get_int64(self):
148        """
149        Fetch a 64-bit int from the stream.
150
151        @return: a 64-bit unsigned integer.
152        @rtype: long
153        """
154        return struct.unpack('>Q', self.get_bytes(8))[0]
155
156    def get_mpint(self):
157        """
158        Fetch a long int (mpint) from the stream.
159
160        @return: an arbitrary-length integer.
161        @rtype: long
162        """
163        return util.inflate_long(self.get_string())
164
165    def get_string(self):
166        """
167        Fetch a string from the stream.  This could be a byte string and may
168        contain unprintable characters.  (It's not unheard of for a string to
169        contain another byte-stream Message.)
170
171        @return: a string.
172        @rtype: string
173        """
174        return self.get_bytes(self.get_int())
175
176    def get_list(self):
177        """
178        Fetch a list of strings from the stream.  These are trivially encoded
179        as comma-separated values in a string.
180
181        @return: a list of strings.
182        @rtype: list of strings
183        """
184        return self.get_string().split(',')
185
186    def add_bytes(self, b):
187        """
188        Write bytes to the stream, without any formatting.
189
190        @param b: bytes to add
191        @type b: str
192        """
193        self.packet.write(b)
194        return self
195
196    def add_byte(self, b):
197        """
198        Write a single byte to the stream, without any formatting.
199
200        @param b: byte to add
201        @type b: str
202        """
203        self.packet.write(b)
204        return self
205
206    def add_boolean(self, b):
207        """
208        Add a boolean value to the stream.
209
210        @param b: boolean value to add
211        @type b: bool
212        """
213        if b:
214            self.add_byte('\x01')
215        else:
216            self.add_byte('\x00')
217        return self
218
219    def add_int(self, n):
220        """
221        Add an integer to the stream.
222
223        @param n: integer to add
224        @type n: int
225        """
226        self.packet.write(struct.pack('>I', n))
227        return self
228
229    def add_int64(self, n):
230        """
231        Add a 64-bit int to the stream.
232
233        @param n: long int to add
234        @type n: long
235        """
236        self.packet.write(struct.pack('>Q', n))
237        return self
238
239    def add_mpint(self, z):
240        """
241        Add a long int to the stream, encoded as an infinite-precision
242        integer.  This method only works on positive numbers.
243
244        @param z: long int to add
245        @type z: long
246        """
247        self.add_string(util.deflate_long(z))
248        return self
249
250    def add_string(self, s):
251        """
252        Add a string to the stream.
253
254        @param s: string to add
255        @type s: str
256        """
257        self.add_int(len(s))
258        self.packet.write(s)
259        return self
260
261    def add_list(self, l):
262        """
263        Add a list of strings to the stream.  They are encoded identically to
264        a single string of values separated by commas.  (Yes, really, that's
265        how SSH2 does it.)
266
267        @param l: list of strings to add
268        @type l: list(str)
269        """
270        self.add_string(','.join(l))
271        return self
272
273    def _add(self, i):
274        if type(i) is str:
275            return self.add_string(i)
276        elif type(i) is int:
277            return self.add_int(i)
278        elif type(i) is long:
279            if i > 0xffffffffL:
280                return self.add_mpint(i)
281            else:
282                return self.add_int(i)
283        elif type(i) is bool:
284            return self.add_boolean(i)
285        elif type(i) is list:
286            return self.add_list(i)
287        else:
288            raise Exception('Unknown type')
289
290    def add(self, *seq):
291        """
292        Add a sequence of items to the stream.  The values are encoded based
293        on their type: str, int, bool, list, or long.
294
295        @param seq: the sequence of items
296        @type seq: sequence
297
298        @bug: longs are encoded non-deterministically.  Don't use this method.
299        """
300        for item in seq:
301            self._add(item)
302