1# Copyright 2014 Google Inc. All rights reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Common code for ADB and Fastboot.
15
16Common usb browsing, and usb communication.
17"""
import logging
import platform
import re
import select
import socket
import threading
import weakref

import libusb1
import usb1

from adb import usb_exceptions
29
# Default timeout, in milliseconds, for USB I/O when a handle is given none.
DEFAULT_TIMEOUT_MS = 10000

# Separators in sysfs-style USB port path strings (e.g. '1-2.3'). Used by
# UsbHandle.PortPathMatcher to turn such a string into a [bus, port, ...] list.
SYSFS_PORT_SPLIT_RE = re.compile(r'[,/:.-]')

_LOG = logging.getLogger('android_usb')
33
34
def GetInterface(setting):
    """Return the (class, subclass, protocol) triple that identifies a USB setting."""
    klass = setting.getClass()
    subclass = setting.getSubClass()
    protocol = setting.getProtocol()
    return klass, subclass, protocol
38
39
def InterfaceMatcher(clazz, subclass, protocol):
    """Build a setting matcher for one specific USB interface.

    Arguments:
      clazz: USB interface class to match.
      subclass: USB interface subclass to match.
      protocol: USB interface protocol to match.

    Returns:
      A function taking a device and returning the first of its settings whose
      GetInterface() triple equals (clazz, subclass, protocol), or None when
      no setting matches.
    """
    wanted = (clazz, subclass, protocol)

    def Matcher(device):
        # Lazily scan settings; stop at the first one that matches.
        candidates = (setting for setting in device.iterSettings()
                      if GetInterface(setting) == wanted)
        return next(candidates, None)

    return Matcher
50
51
class UsbHandle(object):
    """USB communication object. Not thread-safe.

    Handles reading and writing over USB with the proper endpoints, exceptions,
    and interface claiming.

    Important methods:
      FlushBuffers()
      BulkRead(int length)
      BulkWrite(bytes data)
    """

    # Maps a port path tuple to the UsbHandle most recently opened on it. The
    # values are weak references, so entries vanish once a handle is collected.
    _HANDLE_CACHE = weakref.WeakValueDictionary()
    _HANDLE_CACHE_LOCK = threading.Lock()

    def __init__(self, device, setting, usb_info=None, timeout_ms=None):
        """Initialize USB Handle.

        The device is not opened here; call Open() before doing any I/O.

        Arguments:
          device: libusb_device to connect to.
          setting: libusb setting with the correct endpoints to communicate with.
          usb_info: String describing the usb path/serial/device, for debugging.
          timeout_ms: Timeout in milliseconds for all I/O. A falsy value (None
            or 0) selects DEFAULT_TIMEOUT_MS.
        """
        self._setting = setting
        self._device = device
        self._handle = None
        # Filled in by Open().
        self._read_endpoint = None
        self._write_endpoint = None
        self._interface_number = None

        self._usb_info = usb_info or ''
        self._timeout_ms = timeout_ms if timeout_ms else DEFAULT_TIMEOUT_MS
        self._max_read_packet_len = 0

    @property
    def usb_info(self):
        """Debug description, with the serial number appended when known.

        The serial is appended only if it can be read and differs from the
        stored description.
        """
        try:
            sn = self.serial_number
        except libusb1.USBError:
            sn = ''
        if sn and sn != self._usb_info:
            return '%s %s' % (self._usb_info, sn)
        return self._usb_info

    def Open(self):
        """Opens the USB device for this setting, and claims the interface.

        Any UsbHandle previously opened on the same port path is closed first.

        Raises:
          libusb1.USBError: If opening the device, detaching the kernel driver
            or claiming the interface fails. A device handle opened here is
            closed again before the error propagates.
        """
        # Make sure we close any previous handle open to this usb device.
        port_path = tuple(self.port_path)
        with self._HANDLE_CACHE_LOCK:
            old_handle = self._HANDLE_CACHE.get(port_path)
            if old_handle is not None:
                old_handle.Close()

        self._read_endpoint = None
        self._write_endpoint = None

        # Endpoints with the direction bit set are device-to-host (read); the
        # others are host-to-device (write). The last one of each kind wins.
        for endpoint in self._setting.iterEndpoints():
            address = endpoint.getAddress()
            if address & libusb1.USB_ENDPOINT_DIR_MASK:
                self._read_endpoint = address
                self._max_read_packet_len = endpoint.getMaxPacketSize()
            else:
                self._write_endpoint = address

        assert self._read_endpoint is not None
        assert self._write_endpoint is not None

        handle = self._device.open()
        iface_number = self._setting.getNumber()
        try:
            self._DetachKernelDriver(handle, iface_number)
            handle.claimInterface(iface_number)
        except Exception:
            # Don't leak the freshly opened device handle.
            handle.close()
            raise
        self._handle = handle
        self._interface_number = iface_number

        with self._HANDLE_CACHE_LOCK:
            self._HANDLE_CACHE[port_path] = self

    @staticmethod
    def _DetachKernelDriver(handle, iface_number):
        """Detaches an active kernel driver from the interface (skipped on Windows).

        LIBUSB_ERROR_NOT_FOUND is logged and ignored; other USBErrors propagate.
        """
        try:
            if (platform.system() != 'Windows'
                    and handle.kernelDriverActive(iface_number)):
                handle.detachKernelDriver(iface_number)
        except libusb1.USBError as e:
            if e.value == libusb1.LIBUSB_ERROR_NOT_FOUND:
                _LOG.warning('Kernel driver not found for interface: %s.', iface_number)
            else:
                raise

    @property
    def serial_number(self):
        """Serial number string reported by the underlying device."""
        return self._device.getSerialNumber()

    @property
    def port_path(self):
        """List of [bus number] + port numbers locating the device."""
        return [self._device.getBusNumber()] + self._device.getPortNumberList()

    def Close(self):
        """Releases the claimed interface and closes the device handle.

        Safe to call more than once. A USBError raised while closing is logged
        and suppressed. This class registers no finalizer, so callers should
        call Close() explicitly when done.
        """
        handle = self._handle
        if handle is None:
            return
        self._handle = None
        try:
            try:
                handle.releaseInterface(self._interface_number)
            finally:
                # Close the device handle even if releasing the interface failed.
                handle.close()
        except libusb1.USBError:
            _LOG.info('USBError while closing handle %s: ',
                      self.usb_info, exc_info=True)

    def Timeout(self, timeout_ms):
        """Returns timeout_ms, or this handle's default when it is None."""
        return timeout_ms if timeout_ms is not None else self._timeout_ms

    def FlushBuffers(self):
        """Discards pending input by reading until a read times out.

        Each read asks for one max-size packet with a 10 ms timeout.

        Raises:
          usb_exceptions.ReadFailedError: On any read failure other than a
            libusb timeout, including when the handle is closed.
        """
        while True:
            try:
                self.BulkRead(self._max_read_packet_len, timeout_ms=10)
            except usb_exceptions.ReadFailedError as e:
                # usb_error is None when the handle is closed (see BulkRead);
                # only a libusb timeout means the input has been drained.
                usb_error = e.usb_error
                if (usb_error is not None
                        and usb_error.value == libusb1.LIBUSB_ERROR_TIMEOUT):
                    break
                raise

    def BulkWrite(self, data, timeout_ms=None):
        """Writes data to the write endpoint.

        Arguments:
          data: Bytes to send.
          timeout_ms: Timeout in milliseconds; None uses the handle default.

        Returns:
          The result of the underlying bulkWrite call (presumably the number of
          bytes written).

        Raises:
          usb_exceptions.WriteFailedError: If the handle is closed or libusb
            reports an error.
        """
        if self._handle is None:
            raise usb_exceptions.WriteFailedError(
                'This handle has been closed, probably due to another being opened.',
                None)
        try:
            return self._handle.bulkWrite(
                self._write_endpoint, data, timeout=self.Timeout(timeout_ms))
        except libusb1.USBError as e:
            raise usb_exceptions.WriteFailedError(
                'Could not send data to %s (timeout %sms)' % (
                    self.usb_info, self.Timeout(timeout_ms)), e)

    def BulkRead(self, length, timeout_ms=None):
        """Reads up to length bytes from the read endpoint.

        Arguments:
          length: Maximum number of bytes to read.
          timeout_ms: Timeout in milliseconds; None uses the handle default.

        Returns:
          The data read, as a bytearray.

        Raises:
          usb_exceptions.ReadFailedError: If the handle is closed (usb_error is
            None) or libusb reports an error (usb_error is the USBError).
        """
        if self._handle is None:
            raise usb_exceptions.ReadFailedError(
                'This handle has been closed, probably due to another being opened.',
                None)
        try:
            # python-libusb1 > 1.6 exposes bytearray()s now instead of bytes/str.
            # To support older and newer versions, we ensure everything's bytearray()
            # from here on out.
            return bytearray(self._handle.bulkRead(
                self._read_endpoint, length, timeout=self.Timeout(timeout_ms)))
        except libusb1.USBError as e:
            raise usb_exceptions.ReadFailedError(
                'Could not receive data from %s (timeout %sms)' % (
                    self.usb_info, self.Timeout(timeout_ms)), e)

    def BulkReadAsync(self, length, timeout_ms=None):
        """Not implemented; always returns None."""
        # See: https://pypi.python.org/pypi/libusb1 "Asynchronous I/O" section
        return

    @classmethod
    def PortPathMatcher(cls, port_path):
        """Returns a device matcher for the given port path.

        Arguments:
          port_path: A sysfs-style path string (split on SYSFS_PORT_SPLIT_RE),
            or a sequence of ints: [bus number] + port numbers.
        """
        if isinstance(port_path, str):
            # Convert from sysfs path to port_path.
            port_path = [int(part) for part in SYSFS_PORT_SPLIT_RE.split(port_path)]
        else:
            # UsbHandle.port_path is a list; normalize tuples etc. so they compare equal.
            port_path = list(port_path)
        return lambda device: device.port_path == port_path

    @classmethod
    def SerialMatcher(cls, serial):
        """Returns a device matcher for the given serial."""
        return lambda device: device.serial_number == serial

    @classmethod
    def FindAndOpen(cls, setting_matcher,
                    port_path=None, serial=None, timeout_ms=None):
        """Finds the first matching device, opens it and flushes its input.

        Arguments are as for Find(). If flushing fails, the device is closed
        again before the error propagates.
        """
        dev = cls.Find(
            setting_matcher, port_path=port_path, serial=serial,
            timeout_ms=timeout_ms)
        dev.Open()
        try:
            dev.FlushBuffers()
        except Exception:
            dev.Close()
            raise
        return dev

    @classmethod
    def Find(cls, setting_matcher, port_path=None, serial=None, timeout_ms=None):
        """Gets the first device that matches according to the keyword args.

        port_path takes precedence over serial. If neither is given, the first
        device with a setting accepted by setting_matcher is returned.
        """
        if port_path:
            device_matcher = cls.PortPathMatcher(port_path)
            usb_info = port_path
        elif serial:
            device_matcher = cls.SerialMatcher(serial)
            usb_info = serial
        else:
            device_matcher = None
            usb_info = 'first'
        return cls.FindFirst(setting_matcher, device_matcher,
                             usb_info=usb_info, timeout_ms=timeout_ms)

    @classmethod
    def FindFirst(cls, setting_matcher, device_matcher=None, **kwargs):
        """Find and return the first matching device.

        Args:
          setting_matcher: See cls.FindDevices.
          device_matcher: See cls.FindDevices.
          **kwargs: See cls.FindDevices.

        Returns:
          An instance of UsbHandle.

        Raises:
          DeviceNotFoundError: Raised if the device is not available.
        """
        try:
            return next(cls.FindDevices(
                setting_matcher, device_matcher=device_matcher, **kwargs))
        except StopIteration:
            raise usb_exceptions.DeviceNotFoundError(
                'No device available, or it is in the wrong configuration.')

    @classmethod
    def FindDevices(cls, setting_matcher, device_matcher=None,
                    usb_info='', timeout_ms=None):
        """Find and yield the devices that match.

        Args:
          setting_matcher: Function that returns the setting to use given a
            usb1.USBDevice, or None if the device doesn't have a valid setting.
          device_matcher: Function that returns True if the given UsbHandle is
            valid. None to match any device.
          usb_info: Info string describing device(s).
          timeout_ms: Default timeout of commands in milliseconds.

        Yields:
          UsbHandle instances (not yet opened).
        """
        ctx = usb1.USBContext()
        for device in ctx.getDeviceList(skip_on_error=True):
            setting = setting_matcher(device)
            if setting is None:
                continue

            handle = cls(device, setting, usb_info=usb_info, timeout_ms=timeout_ms)
            if device_matcher is None or device_matcher(handle):
                yield handle
287
288
class TcpHandle(object):
    """TCP connection object.

    Provides the same BulkRead/BulkWrite/Timeout/Close methods as UsbHandle.
    All timeouts accepted by this class are in milliseconds.
    """

    def __init__(self, serial, timeout_ms=None):
        """Initialize the TCP Handle and connect to the device.

        Arguments:
          serial: Android device serial of the form host or host:port, as str
            or bytes (bytes are decoded as UTF-8). Host may be an IP address or
            a host name. The port is taken after the last ':', so hosts that
            contain ':' themselves (e.g. '::1:5555') are accepted. Without a
            ':', the port defaults to 5555.
          timeout_ms: Default timeout in milliseconds for connecting and for
            BulkRead/BulkWrite. None or 0 means no timeout (blocking).
        """
        self.host, self.port = self._ParseSerial(serial)

        self._connection = None
        self._serial_number = '%s:%s' % (self.host, self.port)
        self._timeout_ms = float(timeout_ms) if timeout_ms else None

        self._connect()

    @staticmethod
    def _ParseSerial(serial):
        """Splits a 'host' or 'host:port' serial into a (host, port) tuple.

        The port is the string after the last ':', or the int 5555 when the
        serial contains no ':'.
        """
        # If necessary, convert serial to a unicode string.
        if isinstance(serial, (bytes, bytearray)):
            serial = serial.decode('utf-8')
        if ':' in serial:
            host, port = serial.rsplit(':', 1)
            return host, port
        return serial, 5555

    def _connect(self):
        """Opens the TCP connection using the handle's default timeout."""
        timeout = self.TimeoutSeconds(self._timeout_ms)
        self._connection = socket.create_connection((self.host, self.port),
                                                    timeout=timeout)
        if timeout:
            # BulkRead/BulkWrite wait for readiness with select() and the
            # timeout, so the socket is made non-blocking for send/recv.
            self._connection.setblocking(0)

    @property
    def serial_number(self):
        """The 'host:port' string this handle is connected to."""
        return self._serial_number

    def BulkWrite(self, data, timeout=None):
        """Sends data once the socket becomes writable.

        Arguments:
          data: Bytes to send.
          timeout: Timeout in milliseconds; None uses the handle default.

        Returns:
          The number of bytes sent, which may be less than len(data).

        Raises:
          usb_exceptions.TcpTimeoutException: If the socket did not become
            writable in time; nothing was sent.
        """
        t = self.TimeoutSeconds(timeout)
        _, writeable, _ = select.select([], [self._connection], [], t)
        if writeable:
            return self._connection.send(data)
        msg = 'Sending data to {} timed out after {}s. No data was sent.'.format(
            self.serial_number, t)
        raise usb_exceptions.TcpTimeoutException(msg)

    def BulkRead(self, numbytes, timeout=None):
        """Receives up to numbytes bytes once the socket becomes readable.

        Arguments:
          numbytes: Maximum number of bytes to read.
          timeout: Timeout in milliseconds; None uses the handle default.

        Returns:
          The bytes received; empty if the peer closed the connection.

        Raises:
          usb_exceptions.TcpTimeoutException: If no data arrived in time.
        """
        t = self.TimeoutSeconds(timeout)
        readable, _, _ = select.select([self._connection], [], [], t)
        if readable:
            return self._connection.recv(numbytes)
        msg = 'Reading from {} timed out (Timeout {}s)'.format(
            self._serial_number, t)
        raise usb_exceptions.TcpTimeoutException(msg)

    def Timeout(self, timeout_ms):
        """Returns timeout_ms as a float, or the handle default when it is None."""
        return float(timeout_ms) if timeout_ms is not None else self._timeout_ms

    def TimeoutSeconds(self, timeout_ms):
        """Like Timeout(), but converted to seconds (None stays None)."""
        timeout = self.Timeout(timeout_ms)
        return timeout / 1000.0 if timeout is not None else timeout

    def Close(self):
        """Closes the underlying socket."""
        return self._connection.close()
355