1# Copyright 2014 Google Inc. All rights reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""ADB protocol implementation.
15
16Implements the ADB protocol as seen in android's adb/adbd binaries, but only the
17host side.
18"""
19
import codecs
import struct
import time
from io import BytesIO

from adb import usb_exceptions
24
# Largest payload, in bytes, advertised to the device in the CNXN handshake.
MAX_ADB_DATA = 4 * 1024
# ADB protocol version sent in the CNXN handshake.
VERSION = 0x01000000

# Values carried in arg0 of an AUTH message: a token from the device to sign,
# our signature of that token, or our RSA public key.
AUTH_TOKEN, AUTH_SIGNATURE, AUTH_RSAPUBLICKEY = 1, 2, 3
34
35
def find_backspace_runs(stdout_bytes, start_pos):
    """Locate the first run of consecutive backspace bytes at or after start_pos.

    Args:
      stdout_bytes: bytes (or bytearray) to scan.
      start_pos: index in stdout_bytes at which the search begins.

    Returns:
      A (run_start, run_length) tuple: the absolute index of the first
      backspace (0x08) of the run and the number of consecutive backspaces,
      or (-1, 0) when no backspace occurs at or after start_pos.
    """
    run_start = stdout_bytes.find(b'\x08', start_pos)
    if run_start == -1:
        return -1, 0

    run_end = run_start + 1
    # Bound the scan by the buffer length so a run that ends the buffer does
    # not index past the end. Comparing 1-byte slices works for bytes and
    # bytearray alike.
    while (run_end < len(stdout_bytes)
           and stdout_bytes[run_end:run_end + 1] == b'\x08'):
        run_end += 1

    return run_start, run_end - run_start
51
52
class InvalidCommandError(Exception):
    """Raised when the device answers with an unexpected or failing command.

    The exception args are (message, response_header, response_data); when
    the header is b'FAIL' the message is wrapped to say the device reported
    the failure.
    """

    def __init__(self, message, response_header, response_data):
        device_failed = response_header == b'FAIL'
        text = ('Command failed, device said so. (%s)' % message
                if device_failed else message)
        super(InvalidCommandError, self).__init__(
            text, response_header, response_data)
61
62
class InvalidResponseError(Exception):
    """The device's reply to one of our commands was not what we expected."""
65
66
class InvalidChecksumError(Exception):
    """A received payload's checksum differs from the one in its header."""
69
70
class InterleavedDataError(Exception):
    """A packet arrived for another stream; only commands sent serially are supported."""
73
74
def MakeWireIDs(ids):
    """Build lookup tables between command ids and their wire integers.

    Each id's bytes are read as a little-endian unsigned integer (byte i
    contributes byte << 8*i).

    Args:
      ids: iterable of byte-string command ids such as b'CNXN'.

    Returns:
      (id_to_wire, wire_to_id): a dict from id to integer and its inverse.
    """
    def to_wire(cmd_id):
        value = 0
        for shift, byte in enumerate(bytearray(cmd_id)):
            # Distinct byte positions never overlap, so OR equals addition.
            value |= byte << (8 * shift)
        return value

    id_to_wire = dict((cmd_id, to_wire(cmd_id)) for cmd_id in ids)
    wire_to_id = dict((wire, cmd_id) for cmd_id, wire in id_to_wire.items())
    return id_to_wire, wire_to_id
82
83
class AuthSigner(object):
    """Interface for signers used by authenticated ADB (4.4.x/KitKat and later).

    Subclasses provide the private-key signing operation and the matching
    public key; both methods here only raise NotImplementedError.
    """

    def Sign(self, data):
        """Return `data` signed with the private key. Subclasses must override."""
        raise NotImplementedError

    def GetPublicKey(self):
        """Return the public key in PEM format, without headers or newlines.

        Subclasses must override.
        """
        raise NotImplementedError
94
95
96class _AdbConnection(object):
97    """ADB Connection."""
98
99    def __init__(self, usb, local_id, remote_id, timeout_ms):
100        self.usb = usb
101        self.local_id = local_id
102        self.remote_id = remote_id
103        self.timeout_ms = timeout_ms
104
105    def _Send(self, command, arg0, arg1, data=b''):
106        message = AdbMessage(command, arg0, arg1, data)
107        message.Send(self.usb, self.timeout_ms)
108
109    def Write(self, data):
110        """Write a packet and expect an Ack."""
111        self._Send(b'WRTE', arg0=self.local_id, arg1=self.remote_id, data=data)
112        # Expect an ack in response.
113        cmd, okay_data = self.ReadUntil(b'OKAY')
114        if cmd != b'OKAY':
115            if cmd == b'FAIL':
116                raise usb_exceptions.AdbCommandFailureException(
117                    'Command failed.', okay_data)
118            raise InvalidCommandError(
119                'Expected an OKAY in response to a WRITE, got %s (%s)',
120                cmd, okay_data)
121        return len(data)
122
123    def Okay(self):
124        self._Send(b'OKAY', arg0=self.local_id, arg1=self.remote_id)
125
126    def ReadUntil(self, *expected_cmds):
127        """Read a packet, Ack any write packets."""
128        cmd, remote_id, local_id, data = AdbMessage.Read(
129            self.usb, expected_cmds, self.timeout_ms)
130        if local_id != 0 and self.local_id != local_id:
131            raise InterleavedDataError("We don't support multiple streams...")
132        if remote_id != 0 and self.remote_id != remote_id:
133            raise InvalidResponseError(
134                'Incorrect remote id, expected %s got %s' % (
135                    self.remote_id, remote_id))
136        # Ack write packets.
137        if cmd == b'WRTE':
138            self.Okay()
139        return cmd, data
140
141    def ReadUntilClose(self):
142        """Yield packets until a Close packet is received."""
143        while True:
144            cmd, data = self.ReadUntil(b'CLSE', b'WRTE')
145            if cmd == b'CLSE':
146                self._Send(b'CLSE', arg0=self.local_id, arg1=self.remote_id)
147                break
148            if cmd != b'WRTE':
149                if cmd == b'FAIL':
150                    raise usb_exceptions.AdbCommandFailureException(
151                        'Command failed.', data)
152                raise InvalidCommandError('Expected a WRITE or a CLOSE, got %s (%s)',
153                                          cmd, data)
154            yield data
155
156    def Close(self):
157        self._Send(b'CLSE', arg0=self.local_id, arg1=self.remote_id)
158        cmd, data = self.ReadUntil(b'CLSE')
159        if cmd != b'CLSE':
160            if cmd == b'FAIL':
161                raise usb_exceptions.AdbCommandFailureException('Command failed.', data)
162            raise InvalidCommandError('Expected a CLSE response, got %s (%s)',
163                                      cmd, data)
164
165
class AdbMessage(object):
    """ADB Protocol and message class.

    Each message is a 24-byte header of six little-endian 32-bit words
    (command, arg0, arg1, payload length, payload checksum, magic) followed by
    the payload bytes (see Pack and Read).

    Protocol Notes

    local_id/remote_id:
      Turns out the documentation is host/device ambidextrous, so local_id is the
      id for 'the sender' and remote_id is for 'the recipient'. So since we're
      only on the host, we'll re-document with host_id and device_id:

      OPEN(host_id, 0, 'shell:XXX')
      READY/OKAY(device_id, host_id, '')
      WRITE(0, host_id, 'data')
      CLOSE(device_id, host_id, '')
    """

    ids = [b'SYNC', b'CNXN', b'AUTH', b'OPEN', b'OKAY', b'CLSE', b'WRTE']
    # commands maps an id such as b'CNXN' to its wire integer; constants is
    # the reverse mapping.
    commands, constants = MakeWireIDs(ids)
    # An ADB message is 6 words in little-endian.
    format = b'<6I'

    connections = 0

    def __init__(self, command=None, arg0=None, arg1=None, data=b''):
        """Build a message; `command` is an id from `ids`, e.g. b'OPEN'."""
        self.command = self.commands[command]
        # The magic word is the bitwise complement of the command word.
        self.magic = self.command ^ 0xFFFFFFFF
        self.arg0 = arg0
        self.arg1 = arg1
        self.data = data

    @property
    def checksum(self):
        return self.CalculateChecksum(self.data)

    @staticmethod
    def CalculateChecksum(data):
        """Return the sum of all payload bytes, truncated to 32 bits."""
        # The checksum is just a sum of all the bytes. I swear.
        if isinstance(data, bytearray):
            total = sum(data)
        elif isinstance(data, bytes):
            if data and isinstance(data[0], bytes):
                # Python 2 bytes (str) index as single-character strings.
                total = sum(map(ord, data))
            else:
                # Python 3 bytes index as numbers (and PY2 empty strings sum() to 0)
                total = sum(data)
        else:
            # Unicode strings (should never see?)
            total = sum(map(ord, data))
        return total & 0xFFFFFFFF

    def Pack(self):
        """Returns this message in an over-the-wire format."""
        return struct.pack(self.format, self.command, self.arg0, self.arg1,
                           len(self.data), self.checksum, self.magic)

    @classmethod
    def Unpack(cls, message):
        """Split a packed header into (cmd, arg0, arg1, data_length, data_checksum).

        Raises:
          ValueError: `message` does not match the header format.
        """
        try:
            cmd, arg0, arg1, data_length, data_checksum, unused_magic = struct.unpack(
                cls.format, message)
        except struct.error as e:
            raise ValueError('Unable to unpack ADB command.', cls.format, message, e)
        return cmd, arg0, arg1, data_length, data_checksum

    def Send(self, usb, timeout_ms=None):
        """Send this message over USB: the header, then the payload."""
        usb.BulkWrite(self.Pack(), timeout_ms)
        usb.BulkWrite(self.data, timeout_ms)

    @staticmethod
    def _ReadPayload(usb, data_length, timeout_ms):
        """Read data_length payload bytes, looping over short bulk reads.

        Raises:
          InvalidResponseError: a bulk read returned no data, which would
              otherwise loop forever.
        """
        data = bytearray()
        while len(data) < data_length:
            chunk = usb.BulkRead(data_length - len(data), timeout_ms)
            if not chunk:
                raise InvalidResponseError(
                    'Device returned no data with %d payload bytes outstanding'
                    % (data_length - len(data)))
            data += chunk
        return data

    @classmethod
    def Read(cls, usb, expected_cmds, timeout_ms=None, total_timeout_ms=None):
        """Receive a response from the device.

        Reads messages until one whose command is in expected_cmds arrives.
        The payload of each skipped message is read and discarded so the next
        header read stays aligned.

        Args:
          usb: USB handle with BulkRead and Timeout methods.
          expected_cmds: sequence of command ids (e.g. [b'OKAY']) to wait for.
          timeout_ms: per-bulk-read timeout in milliseconds.
          total_timeout_ms: how long, in milliseconds, to keep skipping
              unexpected messages; resolved through usb.Timeout().

        Returns:
          (command, arg0, arg1, data) with data as bytes.

        Raises:
          InvalidCommandError: an unknown command arrived, or no expected
              command arrived within total_timeout_ms.
          InvalidChecksumError: the payload checksum does not match.
        """
        total_timeout_ms = usb.Timeout(total_timeout_ms)
        start = time.time()
        while True:
            # 24 bytes == struct.calcsize(cls.format), the header size.
            msg = usb.BulkRead(24, timeout_ms)
            cmd, arg0, arg1, data_length, data_checksum = cls.Unpack(msg)
            command = cls.constants.get(cmd)
            if not command:
                raise InvalidCommandError(
                    'Unknown command: %x' % cmd, cmd, (arg0, arg1))
            if command in expected_cmds:
                break

            # Discard the skipped message's payload so it isn't parsed as a header.
            if data_length > 0:
                cls._ReadPayload(usb, data_length, timeout_ms)

            # time.time() is in seconds while total_timeout_ms is milliseconds.
            if (time.time() - start) * 1000.0 > total_timeout_ms:
                raise InvalidCommandError(
                    # Wrap in a 1-tuple: a multi-element tuple would be
                    # treated as several format arguments.
                    'Never got one of the expected responses (%s)' % (expected_cmds,),
                    cmd, (timeout_ms, total_timeout_ms))

        if data_length > 0:
            data = cls._ReadPayload(usb, data_length, timeout_ms)
            actual_checksum = cls.CalculateChecksum(data)
            if actual_checksum != data_checksum:
                raise InvalidChecksumError(
                    'Received checksum %s != %s' % (actual_checksum, data_checksum))
        else:
            data = b''
        return command, arg0, arg1, bytes(data)

    @classmethod
    def Connect(cls, usb, banner=b'notadb', rsa_keys=None, auth_timeout_ms=100):
        """Establish a new connection to the device.

        Args:
          usb: A USBHandle with BulkRead and BulkWrite methods.
          banner: A string to send as a host identifier.
          rsa_keys: List of AuthSigner subclass instances to be used for
              authentication. The device can either accept one of these via the Sign
              method, or we will send the result of GetPublicKey from the first one
              if the device doesn't accept any of them.
          auth_timeout_ms: Timeout to wait for when sending a new public key. This
              is only relevant when we send a new public key. The device shows a
              dialog and this timeout is how long to wait for that dialog. If used
              in automation, this should be low to catch such a case as a failure
              quickly; while in interactive settings it should be high to allow
              users to accept the dialog. We default to automation here, so it's low
              by default.

        Returns:
          The device's reported banner. Always starts with the state (device,
              recovery, or sideload), sometimes includes information after a : with
              various product information.

        Raises:
          usb_exceptions.DeviceAuthError: When the device expects authentication,
              but we weren't given any valid keys.
          InvalidResponseError: When the device does authentication in an
              unexpected way.
        """
        # In py3, convert unicode to bytes. In py2, convert str to bytes.
        # It's later joined into a byte string, so in py2, this ends up kind of being a no-op.
        if isinstance(banner, str):
            banner = bytearray(banner, 'utf-8')

        msg = cls(
            command=b'CNXN', arg0=VERSION, arg1=MAX_ADB_DATA,
            data=b'host::%s\0' % banner)
        msg.Send(usb)
        cmd, arg0, arg1, banner = cls.Read(usb, [b'CNXN', b'AUTH'])
        if cmd == b'AUTH':
            if not rsa_keys:
                raise usb_exceptions.DeviceAuthError(
                    'Device authentication required, no keys available.')
            # Loop through our keys, signing the last 'banner' or token.
            for rsa_key in rsa_keys:
                if arg0 != AUTH_TOKEN:
                    raise InvalidResponseError(
                        'Unknown AUTH response: %s %s %s' % (arg0, arg1, banner))

                # Do not mangle the banner property here by converting it to a string
                signed_token = rsa_key.Sign(banner)
                msg = cls(
                    command=b'AUTH', arg0=AUTH_SIGNATURE, arg1=0, data=signed_token)
                msg.Send(usb)
                cmd, arg0, unused_arg1, banner = cls.Read(usb, [b'CNXN', b'AUTH'])
                if cmd == b'CNXN':
                    return banner
            # None of the keys worked, so send a public key.
            msg = cls(
                command=b'AUTH', arg0=AUTH_RSAPUBLICKEY, arg1=0,
                data=rsa_keys[0].GetPublicKey() + b'\0')
            msg.Send(usb)
            try:
                cmd, arg0, unused_arg1, banner = cls.Read(
                    usb, [b'CNXN'], timeout_ms=auth_timeout_ms)
            except usb_exceptions.ReadFailedError as e:
                if e.usb_error.value == -7:  # Timeout.
                    raise usb_exceptions.DeviceAuthError(
                        'Accept auth key on device, then retry.')
                raise
            # This didn't time-out, so we got a CNXN response.
            return banner
        return banner

    @classmethod
    def Open(cls, usb, destination, timeout_ms=None):
        """Opens a new connection to the device via an OPEN message.

        Not the same as the posix 'open' or any other google3 Open methods.

        Args:
          usb: USB device handle with BulkRead and BulkWrite methods.
          destination: The service:command string (bytes; str is UTF-8 encoded).
          timeout_ms: Timeout in milliseconds for USB packets.

        Raises:
          InvalidResponseError: Wrong local_id sent to us.
          InvalidCommandError: Didn't get a ready response.

        Returns:
          An _AdbConnection for the new stream, or None when the device closes
          the stream (i.e. it doesn't support the service).
        """
        if not isinstance(destination, bytes):
            destination = destination.encode('utf8')
        local_id = 1
        msg = cls(
            command=b'OPEN', arg0=local_id, arg1=0,
            data=destination + b'\0')
        msg.Send(usb, timeout_ms)
        cmd, remote_id, their_local_id, _ = cls.Read(usb, [b'CLSE', b'OKAY'],
                                                     timeout_ms=timeout_ms)
        if local_id != their_local_id:
            raise InvalidResponseError(
                'Expected the local_id to be {}, got {}'.format(local_id, their_local_id))
        if cmd == b'CLSE':
            # Some devices seem to be sending CLSE once more after a request, this *should* handle it
            cmd, remote_id, their_local_id, _ = cls.Read(usb, [b'CLSE', b'OKAY'],
                                                         timeout_ms=timeout_ms)
            # Device doesn't support this service.
            if cmd == b'CLSE':
                return None
        if cmd != b'OKAY':
            raise InvalidCommandError('Expected a ready response, got {}'.format(cmd),
                                      cmd, (remote_id, their_local_id))
        return _AdbConnection(usb, local_id, remote_id, timeout_ms)

    @classmethod
    def Command(cls, usb, service, command='', timeout_ms=None):
        """One complete set of USB packets for a single command.

        Sends service:command in a new connection, reading the data for the
        response. All the data is held in memory, large responses will be slow and
        can fill up memory.

        Args:
          usb: USB device handle with BulkRead and BulkWrite methods.
          service: The service on the device to talk to.
          command: The command to send to the service.
          timeout_ms: Timeout for USB packets, in milliseconds.

        Raises:
          InterleavedDataError: Multiple streams running over usb.
          InvalidCommandError: Got an unexpected response command.

        Returns:
          The response from the service, as a str.
        """
        return ''.join(cls.StreamingCommand(usb, service, command, timeout_ms))

    @classmethod
    def StreamingCommand(cls, usb, service, command='', timeout_ms=None):
        """One complete set of USB packets for a single command.

        Sends service:command in a new connection and yields the response as
        it arrives, decoded as UTF-8.

        Args:
          usb: USB device handle with BulkRead and BulkWrite methods.
          service: The service on the device to talk to (bytes or str).
          command: The command to send to the service (bytes or str).
          timeout_ms: Timeout for USB packets, in milliseconds.

        Raises:
          InterleavedDataError: Multiple streams running over usb.
          InvalidCommandError: Got an unexpected response command, or the
              device does not support the service.

        Yields:
          The responses from the service, as str chunks.
        """
        if not isinstance(service, bytes):
            service = service.encode('utf8')
        if not isinstance(command, bytes):
            command = command.encode('utf8')
        destination = b'%s:%s' % (service, command)
        connection = cls.Open(usb, destination=destination, timeout_ms=timeout_ms)
        if connection is None:
            # Open returns None when the device closes the stream immediately.
            raise InvalidCommandError(
                'Device does not support the requested service: %r' % (destination,),
                b'CLSE', None)
        # A multi-byte UTF-8 character may be split across packets, so decode
        # incrementally instead of packet by packet.
        decoder = codecs.getincrementaldecoder('utf8')()
        for data in connection.ReadUntilClose():
            text = decoder.decode(data)
            if text:
                yield text
        tail = decoder.decode(b'', True)
        if tail:
            yield tail

    @staticmethod
    def _PartialDelim(delim):
        """Reduce a prompt like b'shell@hammerhead:/ $' to b'@hammerhead:'.

        The user and directory in a shell prompt can change (e.g.
        root@hammerhead:/data/local/tmp $), so only the '@host:' part is
        matched. Delimiters without '@' followed later by ':/' are returned
        unchanged.
        """
        user_pos = delim.find(b'@')
        dir_pos = delim.rfind(b':/')
        # Require ':/' after '@'; otherwise the slice would be empty and an
        # empty delimiter matches everything.
        if user_pos != -1 and dir_pos > user_pos:
            return delim[user_pos:dir_pos + 1]  # e.g. @hammerhead:
        return delim

    @staticmethod
    def _StripBackspaces(raw):
        """Apply backspaces: drop each 0x08 byte together with the byte before it."""
        cleaned = bytearray()
        for byte in bytearray(raw):
            if byte == 0x08:
                if cleaned:
                    del cleaned[-1]
            else:
                cleaned.append(byte)
        return bytes(cleaned)

    @classmethod
    def InteractiveShellCommand(cls, conn, cmd=None, strip_cmd=True, delim=None, strip_delim=True, clean_stdout=True):
        """Retrieves stdout of the current InteractiveShell and sends a shell command if provided.

        TODO: Should we turn this into a yield based function so we can stream all output?

        Args:
          conn: Instance of AdbConnection.
          cmd: Optional. Command to run on the target (str or bytes).
          strip_cmd: Optional (default True). Strip the echoed command from stdout.
          delim: Optional. Delimiter to look for in the output to know when to
              stop expecting more output (usually the shell prompt).
          strip_delim: Optional (default True). Strip the provided delimiter from
              the output.
          clean_stdout: Optional (default True). Remove every backspace byte
              together with the character it deleted.

        Returns:
          The stdout from the shell command as bytes, right-stripped. Exceptions
          raised while reading or post-processing are printed and swallowed;
          whatever had been computed so far (possibly b'') is then returned.
        """
        if delim is not None and not isinstance(delim, bytes):
            delim = delim.encode('utf-8')

        # Delimiter may be shell@hammerhead:/ $, but the user or directory can
        # change, so the wait loop searches for a partial delimiter.
        partial_delim = cls._PartialDelim(delim) if delim else None

        stdout = b''
        stdout_stream = BytesIO()
        echoed_cmd = b''

        try:
            if cmd:
                echoed_cmd = cmd if isinstance(cmd, bytes) else cmd.encode('utf8')
                # Required. Send a carriage return right after the cmd.
                conn.Write(echoed_cmd + b'\r')

                if delim:
                    # Expect multiple WRTE packets until the delimiter shows up.
                    # Search all output collected so far: the prompt can be
                    # split across packets.
                    while partial_delim not in stdout_stream.getvalue():
                        _, data = conn.ReadUntil(b'WRTE')
                        stdout_stream.write(data)
                else:
                    # Otherwise, expect only a single WRTE packet of output.
                    _, data = conn.ReadUntil(b'WRTE')
                    stdout_stream.write(data)
            else:
                # No cmd provided means we should just expect a single line from
                # the terminal. Use this sparingly.
                read_cmd, data = conn.ReadUntil(b'WRTE')
                if read_cmd == b'WRTE':
                    stdout_stream.write(data)
                else:
                    print("Unhandled cmd: {}".format(read_cmd))

            stdout = stdout_stream.getvalue()
            if clean_stdout:
                stdout = cls._StripBackspaces(stdout)

            # Strip the original cmd that comes back (echoed) in stdout.
            if echoed_cmd and strip_cmd:
                echo_line = echoed_cmd + b'\r\r\n'
                while echo_line in stdout:
                    stdout = stdout.replace(echo_line, b'')

                # NOTE(review): keeps only the text between the first and second
                # b'\r\r\n'; anything after a second one is discarded — confirm intent.
                if b'\r\r\n' in stdout:
                    stdout = stdout.split(b'\r\r\n')[1]

            # Strip delim if requested.
            # TODO: Handle stripping partial delims here - not a deal breaker the way we're handling it now
            if delim and strip_delim:
                stdout = stdout.replace(delim, b'')

            stdout = stdout.rstrip()

        except Exception as e:
            print("InteractiveShell exception (most likely timeout): {}".format(e))

        return stdout
564