1# Copyright 2016 Google Inc. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#    http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""HID Transport for U2F.
16
This module implements the U2F HID Transport protocol as well as methods
for discovering devices implementing this protocol.
19"""
20import logging
21import os
22import struct
23import time
24
25from pyu2f import errors
26from pyu2f import hid
27
28
def HidUsageSelector(device):
  """Reports whether an enumerated HID device advertises the FIDO usage.

  Args:
    device: A device description dict (as produced by hid.Enumerate())
        with 'usage_page' and 'usage' keys.  'usage' is only consulted
        when the usage page matches.

  Returns:
    True when usage_page is 0xf1d0 and usage is 0x01, otherwise False.
  """
  matches = device['usage_page'] == 0xf1d0 and device['usage'] == 0x01
  return bool(matches)
33
34
def DiscoverLocalHIDU2FDevices(selector=HidUsageSelector):
  """Yields a transport for each accessible local U2F HID device.

  Args:
    selector: Predicate applied to every device description returned by
        hid.Enumerate(); only devices for which it returns True are opened.

  Yields:
    UsbHidTransport instances.  Constructing one runs the U2FHID INIT
    handshake, so each yielded transport already has its channel id.
    Devices whose open/initialization raises OSError (e.g. insufficient
    permissions) are skipped.
  """
  for d in hid.Enumerate():
    if not selector(d):
      continue
    try:
      dev = hid.Open(d['path'])
      transport = UsbHidTransport(dev)
    except OSError:
      # Insufficient permissions to access device; skip it.
      continue
    # Yield outside the try so an OSError thrown into this generator by the
    # consumer (generator.throw) propagates instead of being swallowed.
    yield transport
44
45
class UsbHidTransport(object):
  """Implements the U2FHID transport protocol.

  This class implements the U2FHID transport protocol from the
  FIDO U2F specs.  This protocol manages fragmenting longer messages
  over a short hid frame (usually 64 bytes).  It exposes an APDU
  channel through the MSG command as well as a series of other commands
  for configuring and interacting with the device.
  """

  U2FHID_PING = 0x81
  U2FHID_MSG = 0x83
  U2FHID_WINK = 0x88
  U2FHID_PROMPT = 0x87
  U2FHID_INIT = 0x86
  U2FHID_LOCK = 0x84
  U2FHID_ERROR = 0xbf
  U2FHID_SYNC = 0xbc

  U2FHID_BROADCAST_CID = bytearray([0xff, 0xff, 0xff, 0xff])

  ERR_CHANNEL_BUSY = bytearray([0x06])

  class InitPacket(object):
    """Represents an initial U2FHID packet.

    The initial packet carries the metadata necessary to interpret the
    entire packet stream associated with a particular exchange (read or
    write), followed by the first chunk of the message.

    Wire layout (packet_size bytes, zero padded):
      [0:4] cid, [4] cmd, [5:7] big-endian total size, [7:] payload.

    Attributes:
      packet_size: The size of the hid report (packet) used.  Usually 64.
      cid: The channel id for the connection to the device (4 bytes).
      cmd: The command byte (at most 255).
      size: The size of the entire message to be sent (including
          all continuation packets); must be below 2**16.
      payload: The portion of the message to put into the init packet.
          Its length must be at most packet_size - 7 (the overhead for
          an init packet).
    """

    def __init__(self, packet_size, cid, cmd, size, payload):
      self.packet_size = packet_size
      if len(cid) != 4 or cmd > 255 or size >= 2**16:
        raise errors.InvalidPacketError()
      if len(payload) > self.packet_size - 7:
        raise errors.InvalidPacketError()

      self.cid = cid  # byte array
      self.cmd = cmd  # number
      self.size = size  # number (full size of message)
      self.payload = payload  # byte array (for first packet)

    def ToWireFormat(self):
      """Serializes the packet.

      Returns:
        A list of packet_size ints, zero padded after the payload.
      """
      # Size the report from packet_size rather than a hard-coded 64 so that
      # devices with other report lengths get correctly sized init packets
      # (ContPacket.ToWireFormat already does this).
      ret = bytearray(self.packet_size)
      ret[0:4] = self.cid
      ret[4] = self.cmd
      struct.pack_into('>H', ret, 5, self.size)
      ret[7:7 + len(self.payload)] = self.payload
      return list(map(int, ret))

    @staticmethod
    def FromWireFormat(packet_size, data):
      """Deserializes the packet.

      Deserializes the packet from wire format.

      Args:
        packet_size: The size of all packets (usually 64)
        data: List of ints or bytearray containing the data from the wire.

      Returns:
        InitPacket object for specified data

      Raises:
        InvalidPacketError: if the data isn't a valid InitPacket
      """
      ba = bytearray(data)
      if len(ba) != packet_size:
        raise errors.InvalidPacketError()
      cid = ba[0:4]
      cmd = ba[4]
      size = struct.unpack('>H', bytes(ba[5:7]))[0]
      # Keep only the announced size; for multi-packet messages this slice
      # simply stops at the end of the frame.
      payload = ba[7:7 + size]
      return UsbHidTransport.InitPacket(packet_size, cid, cmd, size, payload)

  class ContPacket(object):
    """Represents a continuation U2FHID packet.

    Continuation packets follow the initial packet and contain the
    remaining data of a particular message.

    Wire layout (packet_size bytes, zero padded):
      [0:4] cid, [4] seq, [5:] payload.

    Attributes:
      packet_size: The size of the hid report (packet) used.  Usually 64.
      cid: The channel id for the connection to the device.
      seq: The sequence number for this continuation packet.  The first
          continuation packet is 0 and it increases from there (max 127).
      payload:  The payload to put into this continuation packet.  Its
          length must be at most packet_size - 5 (the overhead of the
          continuation packet is 5).
    """

    def __init__(self, packet_size, cid, seq, payload):
      self.packet_size = packet_size
      self.cid = cid
      self.seq = seq
      self.payload = payload
      if len(payload) > self.packet_size - 5:
        raise errors.InvalidPacketError()
      if seq > 127:
        raise errors.InvalidPacketError()

    def ToWireFormat(self):
      """Serializes the packet.

      Returns:
        A list of packet_size ints, zero padded after the payload.
      """
      ret = bytearray(self.packet_size)
      ret[0:4] = self.cid
      ret[4] = self.seq
      ret[5:5 + len(self.payload)] = self.payload
      return list(map(int, ret))

    @staticmethod
    def FromWireFormat(packet_size, data):
      """Deserializes the packet.

      Deserializes the packet from wire format.

      Args:
        packet_size: The size of all packets (usually 64)
        data: List of ints or bytearray containing the data from the wire.

      Returns:
        ContPacket object for specified data

      Raises:
        InvalidPacketError: if the data isn't a valid ContPacket
      """
      ba = bytearray(data)
      if len(ba) != packet_size:
        raise errors.InvalidPacketError()
      cid = ba[0:4]
      seq = ba[4]
      # We don't know the size limit a priori here without seeing the init
      # packet, so truncation needs to be done in the higher level protocol
      # handling code, unlike the degenerate case of a 1 packet message in an
      # init packet, where the size is known.
      payload = ba[5:]
      return UsbHidTransport.ContPacket(packet_size, cid, seq, payload)

  def __init__(self, hid_device, read_timeout_secs=3.0):
    """Wraps a HID device and performs the U2FHID INIT handshake.

    Args:
      hid_device: An opened HID device exposing GetInReportDataLength,
          GetOutReportDataLength, Read and Write.
      read_timeout_secs: Stored on the instance; note that
          InternalReadFrame does not currently apply it.

    Raises:
      HardwareError: if the in/out report sizes differ or are zero.
    """
    self.hid_device = hid_device

    in_size = hid_device.GetInReportDataLength()
    out_size = hid_device.GetOutReportDataLength()
    if in_size != out_size:
      raise errors.HardwareError(
          'unsupported device with different in/out packet sizes.')
    if in_size == 0:
      raise errors.HardwareError('unable to determine packet size')

    self.packet_size = in_size
    self.read_timeout_secs = read_timeout_secs
    self.logger = logging.getLogger('pyu2f.hidtransport')

    self.InternalInit()

  def SendMsgBytes(self, msg):
    """Sends msg with the U2FHID MSG command and returns the reply payload."""
    return self.InternalExchange(UsbHidTransport.U2FHID_MSG, msg)

  def SendBlink(self, length):
    """Sends the PROMPT command with a single-byte `length` argument."""
    return self.InternalExchange(UsbHidTransport.U2FHID_PROMPT,
                                 bytearray([length]))

  def SendWink(self):
    """Sends the WINK command with an empty payload."""
    return self.InternalExchange(UsbHidTransport.U2FHID_WINK, bytearray([]))

  def SendPing(self, data):
    """Sends the PING command with `data` and returns the reply payload."""
    return self.InternalExchange(UsbHidTransport.U2FHID_PING, data)

  def InternalInit(self):
    """Initializes the device and obtains channel id.

    Sends INIT with an 8-byte random nonce on the broadcast channel, checks
    that the reply echoes the nonce, then adopts the channel id from reply
    bytes 8..11 and records reply byte 12 as u2fhid_version.

    Raises:
      HidError: if the reply is shorter than 17 bytes or the nonce differs.
    """
    self.cid = UsbHidTransport.U2FHID_BROADCAST_CID
    nonce = bytearray(os.urandom(8))
    r = self.InternalExchange(UsbHidTransport.U2FHID_INIT, nonce)
    if len(r) < 17:
      raise errors.HidError('unexpected init reply len')
    if r[0:8] != nonce:
      raise errors.HidError('nonce mismatch')
    self.cid = bytearray(r[8:12])

    self.u2fhid_version = r[12]

  def InternalExchange(self, cmd, payload_in):
    """Sends a command and returns the device's reply payload.

    If the device answers with a channel-busy error, the complete request
    is resent once after a short pause.

    Args:
      cmd: The U2FHID command byte to send.
      payload_in: The request payload (bytes, bytearray or list of ints).
          It is not modified.

    Returns:
      bytearray with the reply payload.

    Raises:
      HidError: on a device error, a reply for a different command, or if
          the device is still busy after the retry.
    """
    self.logger.debug('payload: %s', list(payload_in))
    # Normalize to a private bytearray.  InternalSend leaves it intact, so
    # the busy retry below resends the full request rather than an empty one.
    payload = bytearray(payload_in)
    for _ in range(2):
      self.InternalSend(cmd, payload)
      ret_cmd, ret_payload = self.InternalRecv()

      if ret_cmd == UsbHidTransport.U2FHID_ERROR:
        if ret_payload == UsbHidTransport.ERR_CHANNEL_BUSY:
          time.sleep(0.5)
          continue
        if not ret_payload:
          # Guard against a malformed error reply with no error code byte.
          raise errors.HidError('Device error: (no error code)')
        raise errors.HidError('Device error: %d' % int(ret_payload[0]))
      elif ret_cmd != cmd:
        raise errors.HidError('Command mismatch!')

      return ret_payload
    raise errors.HidError('Device Busy.  Please retry')

  def InternalSend(self, cmd, payload):
    """Sends a message to the device, fragmenting it into packets.

    The first packet is an InitPacket carrying up to packet_size - 7 bytes;
    the remainder goes out in ContPackets of up to packet_size - 5 bytes
    with sequence numbers starting at 0.  `payload` is not modified.

    Args:
      cmd: The U2FHID command byte.
      payload: The full message (bytes, bytearray or list of ints).
    """
    total = len(payload)
    init_capacity = self.packet_size - 7
    cont_capacity = self.packet_size - 5

    first_frame = payload[0:init_capacity]
    first_packet = UsbHidTransport.InitPacket(self.packet_size, self.cid, cmd,
                                              total, first_frame)
    self.InternalSendPacket(first_packet)

    # Walk the payload by offset instead of deleting consumed bytes, so the
    # caller's buffer survives intact for a possible resend.
    offset = len(first_frame)
    seq = 0
    while offset < total:
      next_frame = payload[offset:offset + cont_capacity]
      next_packet = UsbHidTransport.ContPacket(self.packet_size, self.cid, seq,
                                               next_frame)
      self.InternalSendPacket(next_packet)
      offset += len(next_frame)
      seq += 1

  def InternalSendPacket(self, packet):
    """Serializes `packet` and writes it to the HID device."""
    wire = packet.ToWireFormat()
    self.logger.debug('sending packet: %s', wire)
    self.hid_device.Write(wire)

  def InternalReadFrame(self):
    """Reads one raw report from the HID device."""
    # TODO(user): Figure out timeouts.  Today, this implementation
    # blocks forever at the HID level waiting for a response to a report.
    # This may not be reasonable behavior (though in practice in seems to be
    # OK on the set of devices and machines tested so far).
    frame = self.hid_device.Read()
    self.logger.debug('recv: %s', frame)
    return frame

  def InternalRecv(self):
    """Receives a message from the device, including defragmenting it.

    Returns:
      A (cmd, payload) tuple: the command byte from the init packet and a
      bytearray truncated to the size announced in the init packet.

    Raises:
      HardwareError: if continuation packets arrive out of sequence.
      InvalidPacketError: if a frame is not a valid packet.
    """
    first_read = self.InternalReadFrame()
    # NOTE(review): unlike the continuation packets below, the init packet's
    # cid is not compared with self.cid; confirm this is intended.
    first_packet = UsbHidTransport.InitPacket.FromWireFormat(self.packet_size,
                                                             first_read)

    data = first_packet.payload
    to_read = first_packet.size - len(first_packet.payload)

    seq = 0
    while to_read > 0:
      next_read = self.InternalReadFrame()
      next_packet = UsbHidTransport.ContPacket.FromWireFormat(self.packet_size,
                                                              next_read)
      if self.cid != next_packet.cid:
        # Skip over packets that are for communication with other clients.
        # HID is broadcast, so we see potentially all communication from the
        # device.  For well-behaved devices, these should be BUSY messages
        # sent to other clients of the device because at this point we're
        # in mid-message transit.
        continue

      if seq != next_packet.seq:
        raise errors.HardwareError('Packets received out of order')

      # This packet is for us at this point, so debit it against our
      # balance of bytes to read.
      to_read -= len(next_packet.payload)

      data.extend(next_packet.payload)
      seq += 1

    # The last continuation packet is zero padded; drop the padding.
    data = data[0:first_packet.size]
    return (first_packet.cmd, data)
332