"""OATH (TOTP/HOTP) credential management over a smart card connection."""

from .core import (
    int2bytes,
    bytes2int,
    require_version,
    Version,
    Tlv,
    AID,
    BadResponseError,
)
from .core.smartcard import SmartCardConnection, SmartCardProtocol

from urllib.parse import unquote, urlparse, parse_qs
from functools import total_ordering
from enum import IntEnum, unique
from dataclasses import dataclass
from base64 import b64encode, b32decode
from time import time
from typing import Optional, List, Mapping

import hmac
import hashlib
import struct
import os
import re


# TLV tags for credential data
TAG_NAME = 0x71
TAG_NAME_LIST = 0x72
TAG_KEY = 0x73
TAG_CHALLENGE = 0x74
TAG_RESPONSE = 0x75
TAG_TRUNCATED = 0x76
TAG_HOTP = 0x77
TAG_PROPERTY = 0x78
TAG_VERSION = 0x79
TAG_IMF = 0x7A
TAG_TOUCH = 0x7C

# Instruction bytes for commands
INS_LIST = 0xA1
INS_PUT = 0x01
INS_DELETE = 0x02
INS_SET_CODE = 0x03
INS_RESET = 0x04
INS_RENAME = 0x05
INS_CALCULATE = 0xA2
INS_VALIDATE = 0xA3
INS_CALCULATE_ALL = 0xA4
INS_SEND_REMAINING = 0xA5

# Credential ID layout: optional "<period>/" prefix, optional "<issuer>:" prefix,
# then the name. Group 2 is the period, group 4 the issuer, group 5 the name.
TOTP_ID_PATTERN = re.compile(r"^((\d+)/)?(([^:]+):)?(.+)$")

# The type/algorithm byte carries the OATH type in the high nibble and the
# hash algorithm in the low nibble (see OATH_TYPE / HASH_ALGORITHM values).
MASK_ALGO = 0x0F
MASK_TYPE = 0xF0

DEFAULT_PERIOD = 30
DEFAULT_DIGITS = 6
DEFAULT_IMF = 0
CHALLENGE_LEN = 8  # Length of the random challenges used for key validation
HMAC_MINIMUM_KEY_SIZE = 14  # Stored secrets are NUL-padded to at least this length


@unique
class HASH_ALGORITHM(IntEnum):
    SHA1 = 0x01
    SHA256 = 0x02
    SHA512 = 0x03


@unique
class OATH_TYPE(IntEnum):
    HOTP = 0x10
    TOTP = 0x20


PROP_REQUIRE_TOUCH = 0x02


def parse_b32_key(key: str):
    """Decode a Base32 secret, ignoring case and spaces.

    Missing "=" padding is added back, so unpadded secrets are accepted.
    """
    key = key.upper().replace(" ", "")
    key += "=" * (-len(key) % 8)  # Support unpadded
    return b32decode(key)


def _parse_select(response):
    """Parse a SELECT response into a (version, salt, challenge) tuple.

    Salt (TAG_NAME) and challenge (TAG_CHALLENGE) are None when absent.
    """
    data = Tlv.parse_dict(response)
    return (
        Version.from_bytes(data[TAG_VERSION]),
        data.get(TAG_NAME),
        data.get(TAG_CHALLENGE),
    )


@dataclass
class CredentialData:
    """Everything needed to create a new credential on the device."""

    name: str
    oath_type: OATH_TYPE
    hash_algorithm: HASH_ALGORITHM
    secret: bytes
    digits: int = DEFAULT_DIGITS
    period: int = DEFAULT_PERIOD
    counter: int = DEFAULT_IMF
    issuer: Optional[str] = None

    @classmethod
    def parse_uri(cls, uri: str) -> "CredentialData":
        """Build a CredentialData from an ``otpauth://`` URI.

        An "issuer" query parameter takes precedence over an "issuer:" prefix
        in the label.

        Raises:
            ValueError: if the scheme is not "otpauth" or the OATH type is missing.
            KeyError: for an unknown OATH type or algorithm, or a missing secret.
        """
        parsed = urlparse(uri.strip())
        if parsed.scheme != "otpauth":
            raise ValueError("Invalid URI scheme")

        if parsed.hostname is None:
            raise ValueError("Missing OATH type")
        oath_type = OATH_TYPE[parsed.hostname.upper()]

        params = dict((k, v[0]) for k, v in parse_qs(parsed.query).items())
        issuer = None
        name = unquote(parsed.path)[1:]  # Unquote and strip leading /
        if ":" in name:
            issuer, name = name.split(":", 1)

        return cls(
            name=name,
            oath_type=oath_type,
            hash_algorithm=HASH_ALGORITHM[params.get("algorithm", "SHA1").upper()],
            secret=parse_b32_key(params["secret"]),
            digits=int(params.get("digits", DEFAULT_DIGITS)),
            period=int(params.get("period", DEFAULT_PERIOD)),
            counter=int(params.get("counter", DEFAULT_IMF)),
            issuer=params.get("issuer", issuer),
        )

    def get_id(self) -> bytes:
        """Return the on-device credential ID for this data."""
        return _format_cred_id(self.issuer, self.name, self.oath_type, self.period)


@dataclass
class Code:
    """A calculated OTP value together with its validity window."""

    value: str
    valid_from: int
    valid_to: int  # float("inf") for HOTP codes (see _format_code)


@total_ordering
@dataclass(order=False, frozen=True)
class Credential:
    """A credential stored on a specific device.

    Equality and hashing use only (device_id, id). Ordering is by
    case-insensitive (issuer-or-name, name), for display sorting.
    """

    device_id: str
    id: bytes
    issuer: Optional[str]
    name: str
    oath_type: OATH_TYPE
    period: int
    touch_required: Optional[bool]

    def __lt__(self, other):
        if not isinstance(other, Credential):
            return NotImplemented
        a = ((self.issuer or self.name).lower(), self.name.lower())
        b = ((other.issuer or other.name).lower(), other.name.lower())
        return a < b

    def __eq__(self, other):
        return (
            isinstance(other, type(self))
            and self.device_id == other.device_id
            and self.id == other.id
        )

    def __hash__(self):
        return hash((self.device_id, self.id))


def _format_cred_id(issuer, name, oath_type, period=DEFAULT_PERIOD):
    """Encode a credential ID as ``[period/][issuer:]name`` bytes.

    The period prefix is only written for TOTP credentials with a
    non-default period.
    """
    cred_id = ""
    if oath_type == OATH_TYPE.TOTP and period != DEFAULT_PERIOD:
        cred_id += "%d/" % period
    if issuer:
        cred_id += issuer + ":"
    cred_id += name
    return cred_id.encode()


def _parse_cred_id(cred_id, oath_type):
    """Split a credential ID into (issuer, name, period).

    Period is None for non-TOTP credentials and defaults to DEFAULT_PERIOD
    for TOTP IDs without a period prefix.
    """
    data = cred_id.decode()
    if oath_type == OATH_TYPE.TOTP:
        match = TOTP_ID_PATTERN.match(data)
        if match:
            period_str = match.group(2)
            return (
                match.group(4),
                match.group(5),
                int(period_str) if period_str else DEFAULT_PERIOD,
            )
        else:
            return None, data, DEFAULT_PERIOD
    else:
        if ":" in data:
            issuer, data = data.split(":", 1)
        else:
            issuer = None
        return issuer, data, None


def _get_device_id(salt):
    """Derive a short identifier from the device salt.

    The result is the unpadded base64 of the first 16 bytes of SHA-256(salt).
    """
    d = hashlib.sha256(salt).digest()[:16]
    return b64encode(d).replace(b"=", b"").decode()


def _hmac_sha1(key, message):
    """Return the raw HMAC-SHA1 digest of *message* under *key*."""
    return hmac.new(key, message, "sha1").digest()


def _derive_key(salt, passphrase):
    """Derive a 16-byte access key from *passphrase*.

    Uses PBKDF2-HMAC-SHA1 with the device salt and 1000 iterations.
    """
    return hashlib.pbkdf2_hmac("sha1", passphrase.encode(), salt, 1000, 16)


def _hmac_shorten_key(key, algo):
    """Replace *key* with its digest when it exceeds the hash block size.

    HMAC treats over-long keys this way, so the resulting codes are unchanged.
    """
    h = hashlib.new(algo.name.lower())

    if len(key) > h.block_size:
        h.update(key)
        key = h.digest()
    return key


def _get_challenge(timestamp, period):
    """Return the TOTP time step as a big-endian signed 64-bit challenge."""
    time_step = timestamp // period
    return struct.pack(">q", time_step)


def _format_code(credential, timestamp, truncated):
    """Turn a truncated device response into a Code.

    The first byte of *truncated* is the digit count. The remaining bytes hold
    the truncated value, which is masked to 31 bits, reduced modulo
    10**digits and zero-padded.
    """
    if credential.oath_type == OATH_TYPE.TOTP:
        time_step = timestamp // credential.period
        valid_from = time_step * credential.period
        valid_to = (time_step + 1) * credential.period
    else:  # HOTP
        valid_from = timestamp
        valid_to = float("Inf")
    digits = truncated[0]

    return Code(
        str((bytes2int(truncated[1:]) & 0x7FFFFFFF) % 10 ** digits).rjust(digits, "0"),
        valid_from,
        valid_to,
    )


class OathSession:
    """A session with the OATH application over a smart card connection."""

    def __init__(self, connection: SmartCardConnection):
        self.protocol = SmartCardProtocol(connection, INS_SEND_REMAINING)
        self._version, self._salt, self._challenge = _parse_select(
            self.protocol.select(AID.OATH)
        )
        # A challenge in the SELECT response means an access key is set.
        self._has_key = self._challenge is not None
        self._device_id = _get_device_id(self._salt)
        self.protocol.enable_touch_workaround(self._version)

    @property
    def version(self) -> Version:
        return self._version

    @property
    def device_id(self) -> str:
        return self._device_id

    @property
    def has_key(self) -> bool:
        return self._has_key

    @property
    def locked(self) -> bool:
        return self._challenge is not None

    def reset(self) -> None:
        """Reset the application, then re-select it to refresh salt and state."""
        self.protocol.send_apdu(0, INS_RESET, 0xDE, 0xAD)
        _, self._salt, self._challenge = _parse_select(self.protocol.select(AID.OATH))
        self._has_key = False
        self._device_id = _get_device_id(self._salt)

    def derive_key(self, password: str) -> bytes:
        """Derive the access key for *password* using this device's salt."""
        return _derive_key(self._salt, password)

    def validate(self, key: bytes) -> None:
        """Unlock the session with *key*, mutually authenticating the device.

        Raises:
            BadResponseError: if the device's response to our challenge is wrong.
        """
        response = _hmac_sha1(key, self._challenge)
        challenge = os.urandom(CHALLENGE_LEN)
        data = Tlv(TAG_RESPONSE, response) + Tlv(TAG_CHALLENGE, challenge)
        resp = self.protocol.send_apdu(0, INS_VALIDATE, 0, 0, data)
        verification = _hmac_sha1(key, challenge)
        # Constant-time comparison, so the check does not leak timing information.
        if not hmac.compare_digest(Tlv.unpack(TAG_RESPONSE, resp), verification):
            raise BadResponseError(
                "Response from validation does not match verification!"
            )
        self._challenge = None

    def set_key(self, key: bytes) -> None:
        """Set (or replace) the access key protecting the application."""
        challenge = os.urandom(CHALLENGE_LEN)
        response = _hmac_sha1(key, challenge)
        self.protocol.send_apdu(
            0,
            INS_SET_CODE,
            0,
            0,
            (
                Tlv(TAG_KEY, int2bytes(OATH_TYPE.TOTP | HASH_ALGORITHM.SHA1) + key)
                + Tlv(TAG_CHALLENGE, challenge)
                + Tlv(TAG_RESPONSE, response)
            ),
        )
        self._has_key = True

    def unset_key(self) -> None:
        """Remove the access key by sending an empty TAG_KEY."""
        self.protocol.send_apdu(0, INS_SET_CODE, 0, 0, Tlv(TAG_KEY))
        self._has_key = False

    def put_credential(
        self, credential_data: CredentialData, touch_required: bool = False
    ) -> Credential:
        """Store a credential on the device and return a handle to it."""
        d = credential_data
        cred_id = d.get_id()
        secret = _hmac_shorten_key(d.secret, d.hash_algorithm)
        secret = secret.ljust(HMAC_MINIMUM_KEY_SIZE, b"\0")
        data = Tlv(TAG_NAME, cred_id) + Tlv(
            TAG_KEY,
            struct.pack("<BB", d.oath_type | d.hash_algorithm, d.digits) + secret,
        )

        if touch_required:
            # The property is written as a bare tag/value pair, not via Tlv.
            data += struct.pack(b">BB", TAG_PROPERTY, PROP_REQUIRE_TOUCH)

        if d.counter > 0:
            data += Tlv(TAG_IMF, struct.pack(">I", d.counter))

        self.protocol.send_apdu(0, INS_PUT, 0, 0, data)
        return Credential(
            self.device_id,
            cred_id,
            d.issuer,
            d.name,
            d.oath_type,
            d.period,
            touch_required,
        )

    def rename_credential(
        self, credential_id: bytes, name: str, issuer: Optional[str] = None
    ) -> bytes:
        """Rename a credential and return its new ID (requires firmware 5.3.1+).

        The old ID is parsed as TOTP so that any period prefix is carried over.
        NOTE(review): an HOTP name that happens to look like "N/..." would be
        treated as having a period prefix; confirm this is acceptable.
        """
        require_version(self.version, (5, 3, 1))
        _, _, period = _parse_cred_id(credential_id, OATH_TYPE.TOTP)
        new_id = _format_cred_id(issuer, name, OATH_TYPE.TOTP, period)
        self.protocol.send_apdu(
            0, INS_RENAME, 0, 0, Tlv(TAG_NAME, credential_id) + Tlv(TAG_NAME, new_id)
        )
        return new_id

    def list_credentials(self) -> List[Credential]:
        """List the credentials on the device.

        Touch requirement is unknown here, so touch_required is None.
        """
        creds = []
        for tlv in Tlv.parse_list(self.protocol.send_apdu(0, INS_LIST, 0, 0)):
            data = Tlv.unpack(TAG_NAME_LIST, tlv)
            oath_type = OATH_TYPE(MASK_TYPE & data[0])
            cred_id = data[1:]
            issuer, name, period = _parse_cred_id(cred_id, oath_type)
            creds.append(
                Credential(
                    self.device_id, cred_id, issuer, name, oath_type, period, None
                )
            )
        return creds

    def calculate(self, credential_id: bytes, challenge: bytes) -> bytes:
        """Return the full (untruncated) response for *challenge*.

        The leading byte of the device response is dropped.
        """
        resp = Tlv.unpack(
            TAG_RESPONSE,
            self.protocol.send_apdu(
                0,
                INS_CALCULATE,
                0,
                0,
                Tlv(TAG_NAME, credential_id) + Tlv(TAG_CHALLENGE, challenge),
            ),
        )
        return resp[1:]

    def delete_credential(self, credential_id: bytes) -> None:
        """Delete the credential with the given ID."""
        self.protocol.send_apdu(0, INS_DELETE, 0, 0, Tlv(TAG_NAME, credential_id))

    def calculate_all(
        self, timestamp: Optional[int] = None
    ) -> Mapping[Credential, Optional[Code]]:
        """Calculate codes for all credentials at *timestamp* (default: now).

        The code is None for HOTP credentials and touch-required credentials.
        TOTP credentials with a non-default period are recalculated individually.

        Raises:
            BadResponseError: if the device response is not a sequence of
                (name, response) pairs.
        """
        # Explicit None check: timestamp 0 (the epoch) is a valid value.
        timestamp = int(time() if timestamp is None else timestamp)
        challenge = _get_challenge(timestamp, DEFAULT_PERIOD)

        data = Tlv.parse_list(
            self.protocol.send_apdu(
                0, INS_CALCULATE_ALL, 0, 1, Tlv(TAG_CHALLENGE, challenge)
            )
        )
        if len(data) % 2:
            raise BadResponseError("Malformed CALCULATE_ALL response")

        entries = {}
        pairs = iter(data)
        for name_tlv, tlv in zip(pairs, pairs):
            cred_id = Tlv.unpack(TAG_NAME, name_tlv)
            resp_tag = tlv.tag
            oath_type = OATH_TYPE.HOTP if resp_tag == TAG_HOTP else OATH_TYPE.TOTP
            touch = resp_tag == TAG_TOUCH
            issuer, name, period = _parse_cred_id(cred_id, oath_type)

            credential = Credential(
                self.device_id, cred_id, issuer, name, oath_type, period, touch
            )

            code = None  # Will be None for HOTP and touch
            if resp_tag == TAG_TRUNCATED:  # Only TOTP, no-touch here
                if period == DEFAULT_PERIOD:
                    code = _format_code(credential, timestamp, tlv.value)
                else:
                    # Non-standard period, recalculate
                    code = self.calculate_code(credential, timestamp)
            entries[credential] = code

        return entries

    def calculate_code(
        self, credential: Credential, timestamp: Optional[int] = None
    ) -> Code:
        """Calculate a single code at *timestamp* (default: now).

        Raises:
            ValueError: if *credential* belongs to a different device.
        """
        if credential.device_id != self.device_id:
            raise ValueError("Credential does not belong to this YubiKey")

        # Explicit None check: timestamp 0 (the epoch) is a valid value.
        timestamp = int(time() if timestamp is None else timestamp)
        if credential.oath_type == OATH_TYPE.TOTP:
            challenge = _get_challenge(timestamp, credential.period)
        else:  # HOTP
            challenge = b""

        response = Tlv.unpack(
            TAG_TRUNCATED,
            self.protocol.send_apdu(
                0,
                INS_CALCULATE,
                0,
                0x01,  # Truncate
                Tlv(TAG_NAME, credential.id) + Tlv(TAG_CHALLENGE, challenge),
            ),
        )
        return _format_code(credential, timestamp, response)