1from .core import (
2    int2bytes,
3    bytes2int,
4    require_version,
5    Version,
6    Tlv,
7    AID,
8    BadResponseError,
9)
10from .core.smartcard import SmartCardConnection, SmartCardProtocol
11
12from urllib.parse import unquote, urlparse, parse_qs
13from functools import total_ordering
14from enum import IntEnum, unique
15from dataclasses import dataclass
16from base64 import b64encode, b32decode
17from time import time
18from typing import Optional, List, Mapping
19
20import hmac
21import hashlib
22import struct
23import os
24import re
25
26
# --- TLV tags used in OATH command and response payloads ---------------------
TAG_NAME = 0x71
TAG_NAME_LIST = 0x72
TAG_KEY = 0x73
TAG_CHALLENGE = 0x74
TAG_RESPONSE = 0x75
TAG_TRUNCATED = 0x76
TAG_HOTP = 0x77
TAG_PROPERTY = 0x78
TAG_VERSION = 0x79
TAG_IMF = 0x7A
TAG_TOUCH = 0x7C

# --- APDU instruction (INS) bytes ---------------------------------------------
# Credential management
INS_PUT = 0x01
INS_DELETE = 0x02
INS_SET_CODE = 0x03
INS_RESET = 0x04
INS_RENAME = 0x05
# Listing, calculation and authentication
INS_LIST = 0xA1
INS_CALCULATE = 0xA2
INS_VALIDATE = 0xA3
INS_CALCULATE_ALL = 0xA4
INS_SEND_REMAINING = 0xA5

# TOTP credential ID layout: optional "<period>/" prefix (group 2), optional
# "<issuer>:" prefix (group 4), then the account name (group 5).
TOTP_ID_PATTERN = re.compile(r"^((\d+)/)?(([^:]+):)?(.+)$")

# The credential type byte combines a hash algorithm (low nibble) with an
# OATH type (high nibble); these masks split it apart.
MASK_ALGO = 0x0F
MASK_TYPE = 0xF0

# --- Defaults and sizes ---------------------------------------------------------
DEFAULT_PERIOD = 30  # TOTP time step, in seconds
DEFAULT_DIGITS = 6
DEFAULT_IMF = 0  # initial moving factor (HOTP counter)
CHALLENGE_LEN = 8  # bytes
HMAC_MINIMUM_KEY_SIZE = 14  # secrets are zero-padded up to this many bytes
62
63
@unique
class HASH_ALGORITHM(IntEnum):
    """Hash algorithm identifiers, stored in the low nibble of the type byte."""

    SHA1 = 0x01  # HMAC-SHA1
    SHA256 = 0x02  # HMAC-SHA256
    SHA512 = 0x03  # HMAC-SHA512
69
70
@unique
class OATH_TYPE(IntEnum):
    """OATH credential kinds, stored in the high nibble of the type byte."""

    HOTP = 0x10  # counter-based
    TOTP = 0x20  # time-based
75
76
# Property value sent after TAG_PROPERTY to mark a credential as touch-required.
PROP_REQUIRE_TOUCH = 0x02
78
79
def parse_b32_key(key: str):
    """Decode a Base32 secret, accepting lowercase, spaces and missing padding.

    :param key: Base32 text, e.g. the ``secret`` parameter of an otpauth URI.
    :returns: The decoded secret bytes.
    :raises binascii.Error: If the cleaned text is not valid Base32.
    """
    cleaned = key.upper().replace(" ", "")
    # Base32 data must be a multiple of 8 characters; add the missing "=".
    missing = (8 - len(cleaned) % 8) % 8
    return b32decode(cleaned + "=" * missing)
84
85
def _parse_select(response):
    """Parse the OATH SELECT response into ``(version, salt, challenge)``.

    The version tag is mandatory. The name tag (used as the salt by
    OathSession) and the challenge tag are optional and default to None.
    """
    fields = Tlv.parse_dict(response)
    version = Version.from_bytes(fields[TAG_VERSION])
    salt = fields.get(TAG_NAME)
    challenge = fields.get(TAG_CHALLENGE)
    return version, salt, challenge
93
94
@dataclass
class CredentialData:
    """Parameters and secret describing an OATH credential to be stored.

    Instances are typically created from an ``otpauth://`` URI with
    :meth:`parse_uri`, then passed to ``OathSession.put_credential``.
    """

    name: str
    oath_type: OATH_TYPE
    hash_algorithm: HASH_ALGORITHM
    secret: bytes
    digits: int = DEFAULT_DIGITS
    # Time step in seconds; only encoded into the ID for TOTP credentials.
    period: int = DEFAULT_PERIOD
    # Initial HOTP counter; only sent to the device when greater than 0.
    counter: int = DEFAULT_IMF
    issuer: Optional[str] = None

    @classmethod
    def parse_uri(cls, uri: str) -> "CredentialData":
        """Parse credential data from an ``otpauth://TYPE/LABEL?PARAMS`` URI.

        The label may be ``issuer:account``. Spaces after the colon are
        optional separators per the Key URI format, so they are not kept as
        part of the account name. An ``issuer`` query parameter takes
        precedence over the label prefix.

        :param uri: The URI to parse; surrounding whitespace is ignored.
        :returns: The parsed credential data.
        :raises ValueError: If the scheme is not ``otpauth``, the OATH type is
            missing, or a numeric/Base32 parameter is malformed.
        :raises KeyError: If the OATH type or algorithm is unknown, or the
            ``secret`` parameter is missing.
        """
        parsed = urlparse(uri.strip())
        if parsed.scheme != "otpauth":
            raise ValueError("Invalid URI scheme")

        if parsed.hostname is None:
            raise ValueError("Missing OATH type")
        oath_type = OATH_TYPE[parsed.hostname.upper()]

        # parse_qs maps each key to a list of values; keep the first one.
        params = {k: v[0] for k, v in parse_qs(parsed.query).items()}
        issuer = None
        name = unquote(parsed.path)[1:]  # Unquote and strip leading /
        if ":" in name:
            issuer, name = name.split(":", 1)
            # "Issuer: account" -> account name without the separator spaces.
            name = name.lstrip()

        return cls(
            name=name,
            oath_type=oath_type,
            hash_algorithm=HASH_ALGORITHM[params.get("algorithm", "SHA1").upper()],
            secret=parse_b32_key(params["secret"]),
            digits=int(params.get("digits", DEFAULT_DIGITS)),
            period=int(params.get("period", DEFAULT_PERIOD)),
            counter=int(params.get("counter", DEFAULT_IMF)),
            issuer=params.get("issuer", issuer),
        )

    def get_id(self) -> bytes:
        """Return the on-device credential ID encoding issuer, name and period."""
        return _format_cred_id(self.issuer, self.name, self.oath_type, self.period)
135
136
@dataclass
class Code:
    """A calculated OTP value together with its validity window.

    ``valid_from`` and ``valid_to`` are Unix timestamps in seconds. For TOTP
    they bound the time step the code belongs to. For HOTP codes,
    ``valid_to`` is ``float("inf")`` because they are not bound to a time
    window.
    """

    value: str
    valid_from: int
    # float rather than int: HOTP codes use float("inf") as the upper bound.
    valid_to: float
142
143
@total_ordering
@dataclass(order=False, frozen=True)
class Credential:
    """A credential stored on a specific YubiKey.

    Identity (equality and hashing) uses only ``device_id`` and ``id``.
    Ordering is by case-insensitive issuer (falling back to the name), then
    name, which is suitable for sorting credentials for display.
    """

    device_id: str
    id: bytes
    issuer: Optional[str]
    name: str
    oath_type: OATH_TYPE
    # None for HOTP credentials parsed from device responses.
    period: Optional[int]
    # None when the source of this object did not report touch status.
    touch_required: Optional[bool]

    def _sort_key(self):
        """Key used for ordering: (issuer-or-name, name), both lowercased."""
        return ((self.issuer or self.name).lower(), self.name.lower())

    def __lt__(self, other):
        # Return NotImplemented for foreign types, as total_ordering expects,
        # so Python raises a proper TypeError instead of an AttributeError.
        if not isinstance(other, Credential):
            return NotImplemented
        return self._sort_key() < other._sort_key()

    def __eq__(self, other):
        return (
            isinstance(other, type(self))
            and self.device_id == other.device_id
            and self.id == other.id
        )

    def __hash__(self):
        return hash((self.device_id, self.id))
169
170
def _format_cred_id(issuer, name, oath_type, period=DEFAULT_PERIOD):
    """Build the on-device credential ID as UTF-8 bytes.

    The layout is ``[period/][issuer:]name``. The period prefix is written
    only for TOTP credentials with a non-default period, and the issuer
    prefix only when ``issuer`` is truthy.
    """
    parts = []
    if oath_type == OATH_TYPE.TOTP and period != DEFAULT_PERIOD:
        parts.append("%d/" % period)
    if issuer:
        parts.append(issuer + ":")
    parts.append(name)
    return "".join(parts).encode()
179
180
def _parse_cred_id(cred_id, oath_type):
    """Split a credential ID into ``(issuer, name, period)``.

    For TOTP, the ID is matched against TOTP_ID_PATTERN, and the period
    defaults to DEFAULT_PERIOD when no prefix is present. For any other type,
    the text before the first ":" (if any) is the issuer and the period is
    None.
    """
    text = cred_id.decode()

    if oath_type != OATH_TYPE.TOTP:
        issuer, separator, remainder = text.partition(":")
        if separator:
            return issuer, remainder, None
        return None, text, None

    match = TOTP_ID_PATTERN.match(text)
    if match is None:
        return None, text, DEFAULT_PERIOD
    period_str = match.group(2)
    period = int(period_str) if period_str else DEFAULT_PERIOD
    return match.group(4), match.group(5), period
200
201
202def _get_device_id(salt):
203    d = hashlib.sha256(salt).digest()[:16]
204    return b64encode(d).replace(b"=", b"").decode()
205
206
207def _hmac_sha1(key, message):
208    return hmac.new(key, message, "sha1").digest()
209
210
211def _derive_key(salt, passphrase):
212    return hashlib.pbkdf2_hmac("sha1", passphrase.encode(), salt, 1000, 16)
213
214
215def _hmac_shorten_key(key, algo):
216    h = hashlib.new(algo.name)
217
218    if len(key) > h.block_size:
219        h.update(key)
220        key = h.digest()
221    return key
222
223
224def _get_challenge(timestamp, period):
225    time_step = timestamp // period
226    return struct.pack(">q", time_step)
227
228
def _format_code(credential, timestamp, truncated):
    """Build a Code from a truncated calculate response.

    ``truncated`` carries the digit count in its first byte, followed by the
    truncated HMAC value. The value's top bit is masked off, and the value is
    reduced modulo 10**digits and zero-padded.

    TOTP codes are valid for the time step that contains ``timestamp``.
    HOTP codes are valid from ``timestamp`` onward with no upper bound.
    """
    if credential.oath_type == OATH_TYPE.TOTP:
        step = timestamp // credential.period
        window = (step * credential.period, (step + 1) * credential.period)
    else:  # HOTP: not tied to a time window
        window = (timestamp, float("Inf"))

    digits = truncated[0]
    number = (bytes2int(truncated[1:]) & 0x7FFFFFFF) % 10 ** digits
    return Code(str(number).rjust(digits, "0"), *window)
244
245
class OathSession:
    """Session with the OATH application on a YubiKey over a smart card link.

    Constructing a session selects the OATH application and reads back the
    firmware version, a salt (used for key derivation and the device ID) and,
    when an access key is set, a challenge that must be answered with
    :meth:`validate` before the session is unlocked.
    """

    def __init__(self, connection: SmartCardConnection):
        self.protocol = SmartCardProtocol(connection, INS_SEND_REMAINING)
        self._version, self._salt, self._challenge = _parse_select(
            self.protocol.select(AID.OATH)
        )
        # A challenge in the SELECT response means an access key is set.
        self._has_key = self._challenge is not None
        self._device_id = _get_device_id(self._salt)
        self.protocol.enable_touch_workaround(self._version)

    @staticmethod
    def _resolve_timestamp(timestamp: Optional[int]) -> int:
        """Return ``timestamp`` as an int, using the current time if it is None.

        Only ``None`` selects "now". An explicit ``0`` (the Unix epoch) is
        honored rather than silently replaced by the current time.
        """
        return int(time() if timestamp is None else timestamp)

    @property
    def version(self) -> Version:
        """Firmware version reported when the application was selected."""
        return self._version

    @property
    def device_id(self) -> str:
        """Identifier derived from the application salt."""
        return self._device_id

    @property
    def has_key(self) -> bool:
        """Whether an access key is set, as tracked by this session."""
        return self._has_key

    @property
    def locked(self) -> bool:
        """Whether a pending challenge still has to be answered via validate()."""
        return self._challenge is not None

    def reset(self) -> None:
        """Reset the OATH application and refresh the session state.

        Sends the RESET instruction, then re-selects the application so the
        salt, challenge and device ID reflect the new state. The access key
        is treated as removed afterwards.
        """
        self.protocol.send_apdu(0, INS_RESET, 0xDE, 0xAD)
        _, self._salt, self._challenge = _parse_select(self.protocol.select(AID.OATH))
        self._has_key = False
        self._device_id = _get_device_id(self._salt)

    def derive_key(self, password: str) -> bytes:
        """Derive an access key from ``password`` using the application salt."""
        return _derive_key(self._salt, password)

    def validate(self, key: bytes) -> None:
        """Unlock the session by proving knowledge of ``key``.

        Answers the device's pending challenge with HMAC-SHA1. It also sends a
        fresh random challenge so the device must prove it knows the key too.

        :raises BadResponseError: If the device's response to our challenge
            does not match the expected HMAC.
        """
        response = _hmac_sha1(key, self._challenge)
        challenge = os.urandom(CHALLENGE_LEN)
        data = Tlv(TAG_RESPONSE, response) + Tlv(TAG_CHALLENGE, challenge)
        resp = self.protocol.send_apdu(0, INS_VALIDATE, 0, 0, data)
        verification = _hmac_sha1(key, challenge)
        # Constant-time comparison so the check does not leak timing data.
        if not hmac.compare_digest(Tlv.unpack(TAG_RESPONSE, resp), verification):
            raise BadResponseError(
                "Response from validation does not match verification!"
            )
        self._challenge = None

    def set_key(self, key: bytes) -> None:
        """Set ``key`` as the access key.

        The key is sent with a TOTP|SHA1 type byte, together with a random
        challenge and its HMAC-SHA1 response under the new key.
        """
        challenge = os.urandom(CHALLENGE_LEN)
        response = _hmac_sha1(key, challenge)
        self.protocol.send_apdu(
            0,
            INS_SET_CODE,
            0,
            0,
            (
                Tlv(TAG_KEY, int2bytes(OATH_TYPE.TOTP | HASH_ALGORITHM.SHA1) + key)
                + Tlv(TAG_CHALLENGE, challenge)
                + Tlv(TAG_RESPONSE, response)
            ),
        )
        self._has_key = True

    def unset_key(self) -> None:
        """Remove the access key by sending an empty key TLV."""
        self.protocol.send_apdu(0, INS_SET_CODE, 0, 0, Tlv(TAG_KEY))
        self._has_key = False

    def put_credential(
        self, credential_data: CredentialData, touch_required: bool = False
    ) -> Credential:
        """Store a credential on the device.

        The secret is pre-hashed if it is longer than the hash block size,
        then zero-padded to HMAC_MINIMUM_KEY_SIZE bytes.

        :param credential_data: Parameters and secret of the credential.
        :param touch_required: Set the PROP_REQUIRE_TOUCH property.
        :returns: A Credential handle for the stored credential.
        """
        d = credential_data
        cred_id = d.get_id()
        secret = _hmac_shorten_key(d.secret, d.hash_algorithm)
        secret = secret.ljust(HMAC_MINIMUM_KEY_SIZE, b"\0")
        # Key TLV value: type|algorithm byte, digit count, then the secret.
        data = Tlv(TAG_NAME, cred_id) + Tlv(
            TAG_KEY,
            struct.pack("<BB", d.oath_type | d.hash_algorithm, d.digits) + secret,
        )

        if touch_required:
            # The property is a bare tag byte followed directly by its value.
            data += struct.pack(b">BB", TAG_PROPERTY, PROP_REQUIRE_TOUCH)

        if d.counter > 0:
            data += Tlv(TAG_IMF, struct.pack(">I", d.counter))

        self.protocol.send_apdu(0, INS_PUT, 0, 0, data)
        return Credential(
            self.device_id,
            cred_id,
            d.issuer,
            d.name,
            d.oath_type,
            d.period,
            touch_required,
        )

    def rename_credential(
        self, credential_id: bytes, name: str, issuer: Optional[str] = None
    ) -> bytes:
        """Rename a credential, keeping any non-default period prefix.

        Requires firmware 5.3.1 or later.

        :returns: The new credential ID.
        """
        require_version(self.version, (5, 3, 1))
        _, _, period = _parse_cred_id(credential_id, OATH_TYPE.TOTP)
        new_id = _format_cred_id(issuer, name, OATH_TYPE.TOTP, period)
        self.protocol.send_apdu(
            0, INS_RENAME, 0, 0, Tlv(TAG_NAME, credential_id) + Tlv(TAG_NAME, new_id)
        )
        return new_id

    def list_credentials(self) -> List[Credential]:
        """List all credentials stored on the device.

        ``touch_required`` is None on the returned objects because this
        command does not report it.
        """
        creds = []
        for tlv in Tlv.parse_list(self.protocol.send_apdu(0, INS_LIST, 0, 0)):
            data = Tlv.unpack(TAG_NAME_LIST, tlv)
            # First byte: type/algorithm; the rest is the credential ID.
            oath_type = OATH_TYPE(MASK_TYPE & data[0])
            cred_id = data[1:]
            issuer, name, period = _parse_cred_id(cred_id, oath_type)
            creds.append(
                Credential(
                    self.device_id, cred_id, issuer, name, oath_type, period, None
                )
            )
        return creds

    def calculate(self, credential_id: bytes, challenge: bytes) -> bytes:
        """Calculate a full (non-truncated) response for ``challenge``.

        :returns: The response TLV value without its first byte (presumably
            the digit count, as in truncated responses; confirm against the
            device protocol).
        """
        resp = Tlv.unpack(
            TAG_RESPONSE,
            self.protocol.send_apdu(
                0,
                INS_CALCULATE,
                0,
                0,
                Tlv(TAG_NAME, credential_id) + Tlv(TAG_CHALLENGE, challenge),
            ),
        )
        return resp[1:]

    def delete_credential(self, credential_id: bytes) -> None:
        """Delete the credential with ID ``credential_id``."""
        self.protocol.send_apdu(0, INS_DELETE, 0, 0, Tlv(TAG_NAME, credential_id))

    def calculate_all(
        self, timestamp: Optional[int] = None
    ) -> Mapping[Credential, Optional[Code]]:
        """Calculate codes for all credentials with one command.

        :param timestamp: Unix time to calculate for; None means now.
        :returns: A mapping of every credential to its code. The code is None
            for HOTP and touch-required credentials, which must be calculated
            individually.
        """
        timestamp = self._resolve_timestamp(timestamp)
        challenge = _get_challenge(timestamp, DEFAULT_PERIOD)

        entries = {}
        data = Tlv.parse_list(
            self.protocol.send_apdu(
                0, INS_CALCULATE_ALL, 0, 1, Tlv(TAG_CHALLENGE, challenge)
            )
        )
        # The response alternates: name TLV, then a response TLV whose tag
        # tells the kind (truncated code, HOTP or touch-required).
        while data:
            cred_id = Tlv.unpack(TAG_NAME, data.pop(0))
            tlv = data.pop(0)
            resp_tag = tlv.tag
            oath_type = OATH_TYPE.HOTP if resp_tag == TAG_HOTP else OATH_TYPE.TOTP
            touch = resp_tag == TAG_TOUCH
            issuer, name, period = _parse_cred_id(cred_id, oath_type)

            credential = Credential(
                self.device_id, cred_id, issuer, name, oath_type, period, touch
            )

            code = None  # Will be None for HOTP and touch
            if resp_tag == TAG_TRUNCATED:  # Only TOTP, no-touch here
                if period == DEFAULT_PERIOD:
                    code = _format_code(credential, timestamp, tlv.value)
                else:
                    # The batch challenge used DEFAULT_PERIOD; recalculate.
                    code = self.calculate_code(credential, timestamp)
            entries[credential] = code

        return entries

    def calculate_code(
        self, credential: Credential, timestamp: Optional[int] = None
    ) -> Code:
        """Calculate the current code for a single credential.

        :param credential: A credential belonging to this device.
        :param timestamp: Unix time to calculate for; None means now.
        :raises ValueError: If the credential belongs to a different device.
        """
        if credential.device_id != self.device_id:
            raise ValueError("Credential does not belong to this YubiKey")

        timestamp = self._resolve_timestamp(timestamp)
        if credential.oath_type == OATH_TYPE.TOTP:
            challenge = _get_challenge(timestamp, credential.period)
        else:  # HOTP
            challenge = b""

        response = Tlv.unpack(
            TAG_TRUNCATED,
            self.protocol.send_apdu(
                0,
                INS_CALCULATE,
                0,
                0x01,  # Truncate
                Tlv(TAG_NAME, credential.id) + Tlv(TAG_CHALLENGE, challenge),
            ),
        )
        return _format_code(credential, timestamp, response)
441