1# A part of pdfrw (https://github.com/pmaupin/pdfrw)
2# Copyright (C) 2017  Jon Lund Steffensen
3# MIT license -- See LICENSE.txt for details
4
5from __future__ import division
6
7import hashlib
8import struct
9
10try:
11    from Crypto.Cipher import ARC4, AES
12    HAS_CRYPTO = True
13except ImportError:
14    HAS_CRYPTO = False
15
16from .objects import PdfDict, PdfName
17
# The 32-byte password padding string defined by the PDF standard security
# handler.  It is written as a bytes literal so that it is a byte string on
# Python 3 as well as on Python 2 (where ``b'...'`` is the same as a plain
# ``'...'`` literal); hashlib does not accept text strings on Python 3.
_PASSWORD_PAD = (
    b'(\xbfN^Nu\x8aAd\x00NV\xff\xfa\x01\x08'
    b'..\x00\xb6\xd0h>\x80/\x0c\xa9\xfedSiz')
21
22
def streamobjects(mylist, isinstance=isinstance, PdfDict=PdfDict):
    """Yield, in order, the items of mylist that are PdfDicts with a stream.

    Items that are not PdfDict instances, and PdfDicts whose ``stream``
    attribute is None, are skipped.  ``isinstance`` and ``PdfDict`` are
    keyword defaults bound at definition time.
    """
    for candidate in mylist:
        if not isinstance(candidate, PdfDict):
            continue
        if candidate.stream is None:
            continue
        yield candidate
27
28
def create_key(password, doc):
    """Create an encryption key (Algorithm 2 in PDF spec).

    Parameters:
        password: the user password, either a byte string or a text
            string.  Text is encoded as Latin-1 before hashing, so a text
            password containing characters outside Latin-1 raises
            UnicodeEncodeError.
        doc: object whose ``Encrypt`` entry supplies ``Length``, ``O``,
            ``P`` and ``R``, and whose ``ID[0]`` is hashed into the key.

    Returns the key as a byte string of ``Length // 8`` bytes (5 bytes
    when ``Length`` is absent).
    """
    def _as_bytes(value):
        # hashlib only accepts byte strings.  On Python 2 ``str`` already
        # is ``bytes``, so this only converts real text strings.
        if isinstance(value, bytes):
            return value
        return value.encode('latin-1')

    key_size = int(doc.Encrypt.Length or 40) // 8
    # Pad (or truncate) the password to exactly 32 bytes.
    padded_pass = (_as_bytes(password) + _as_bytes(_PASSWORD_PAD))[:32]
    hasher = hashlib.md5()
    hasher.update(padded_pass)
    hasher.update(doc.Encrypt.O.to_bytes())
    # /P is hashed as an unsigned 32-bit little-endian word.  Masking gives
    # the same four bytes as a signed pack for negative values and also
    # accepts files that store the flags as a large positive integer.
    hasher.update(struct.pack('<I', int(doc.Encrypt.P) & 0xffffffff))
    hasher.update(doc.ID[0].to_bytes())
    # NOTE(review): the extra 0xFFFFFFFF step of Algorithm 2 for revision 4
    # files with unencrypted metadata is not performed here.
    temp_hash = hasher.digest()

    if int(doc.Encrypt.R or 0) >= 3:
        # Revision 3+ re-hashes the first key_size bytes 50 more times.
        for _ in range(50):
            temp_hash = hashlib.md5(temp_hash[:key_size]).digest()

    return temp_hash[:key_size]
45
46
def create_user_hash(key, doc):
    """Create the user password hash (Algorithm 4/5).

    Parameters:
        key: the encryption key as a byte string (see create_key).
        doc: object whose ``Encrypt.R`` selects the algorithm and whose
            ``ID[0]`` is hashed for revision 3 and above.

    Returns a byte string: the RC4-encrypted padding string for revisions
    below 3, otherwise the 16-byte result of 20 chained RC4 passes over an
    MD5 digest.
    """
    # Work on bytes regardless of how the module constant is spelled; on
    # Python 2 ``str`` already is ``bytes`` so nothing is converted.
    pad = _PASSWORD_PAD
    if not isinstance(pad, bytes):
        pad = pad.encode('latin-1')

    revision = int(doc.Encrypt.R or 0)
    if revision < 3:
        # Algorithm 4: encrypt the padding string with the key.
        return ARC4.new(key).encrypt(pad)

    # Algorithm 5: MD5 of padding + ID, then 20 RC4 passes.
    hasher = hashlib.md5()
    hasher.update(pad)
    hasher.update(doc.ID[0].to_bytes())
    temp_hash = hasher.digest()

    # bytearray gives integer items on both Python 2 and 3, so the XOR
    # does not depend on how the interpreter iterates byte strings.
    key_bytes = bytearray(key)
    for i in range(20):
        # Pass i uses the key with every byte XORed with i (pass 0 is the
        # unmodified key).
        temp_key = bytes(bytearray(i ^ b for b in key_bytes))
        temp_hash = ARC4.new(temp_key).encrypt(temp_hash)

    return temp_hash
65
66
def check_user_password(key, doc):
    """Check that the user password is correct (Algorithm 6).

    Recomputes the user hash from key and compares it with the stored
    ``Encrypt.U`` value.  For revisions below 3 the whole stored value is
    compared; otherwise only its first 16 bytes are.  Returns a bool.
    """
    computed = create_user_hash(key, doc)
    compare_whole = int(doc.Encrypt.R or 0) < 3
    stored = doc.Encrypt.U.to_bytes()
    if not compare_whole:
        stored = stored[:16]
    return stored == computed
75
76
class AESCryptFilter(object):
    """Crypt filter corresponding to /AESV2.

    Holds the document encryption key and derives a per-object AES key
    from it for every call to decrypt_data.
    """
    def __init__(self, key):
        # Document-level encryption key (byte string).
        self._key = key

    def decrypt_data(self, num, gen, data):
        """Decrypt data (string/stream) using key (Algorithm 1).

        Parameters:
            num: object number of the object that owns the data.
            gen: generation number of that object.
            data: byte string whose first AES block is the initialization
                vector, followed by the CBC-encrypted payload.

        Returns the decrypted payload with its trailing padding removed.

        Raises AssertionError when the final byte is not a valid padding
        length (1..16), and IndexError when there is no payload after the
        initialization vector.
        """
        # Per-object key: MD5(key + low 3 bytes of num + low 2 bytes of
        # gen + b'sAlT').  The salt is a bytes literal so the
        # concatenation also works on Python 3.
        key_extension = struct.pack('<i', num)[:3]
        key_extension += struct.pack('<i', gen)[:2]
        key_extension += b'sAlT'
        temp_key = hashlib.md5(self._key + key_extension).digest()

        iv = data[:AES.block_size]
        cipher = AES.new(temp_key, AES.MODE_CBC, iv)
        decrypted = cipher.decrypt(data[AES.block_size:])

        # Remove padding.  Going through bytearray yields an int on both
        # Python 2 and 3 (indexing bytes directly differs between them).
        pad_size = bytearray(decrypted[-1:])[0]
        # Raised explicitly instead of via ``assert`` so the check is not
        # stripped under ``python -O`` (a pad size of 0 would otherwise
        # silently return empty data).  AssertionError is kept so callers
        # see the same exception type as before.
        if not 1 <= pad_size <= 16:
            raise AssertionError(
                'invalid AES padding length: %d' % pad_size)
        return decrypted[:-pad_size]
98
99
class RC4CryptFilter(object):
    """Crypt filter corresponding to /V2.

    Keeps the document encryption key and derives a per-object RC4 key
    from it whenever data is decrypted.
    """
    def __init__(self, key):
        self._key = key

    def decrypt_data(self, num, gen, data):
        """Decrypt data (string/stream) using key (Algorithm 1).

        The per-object key is the MD5 digest of the document key followed
        by the low 3 bytes of num and the low 2 bytes of gen, truncated to
        ``min(len(key) + 5, 16)`` bytes.  Returns the RC4-decrypted data.
        """
        base_key = self._key
        object_key_len = min(len(base_key) + 5, 16)
        suffix = struct.pack('<i', num)[:3] + struct.pack('<i', gen)[:2]
        digest = hashlib.md5(base_key + suffix).digest()
        object_key = digest[:object_key_len]
        return ARC4.new(object_key).decrypt(data)
115
116
class IdentityCryptFilter(object):
    """Identity crypt filter (pass through with no encryption).

    Offers the same decrypt_data(num, gen, data) method as the other
    crypt filters in this module.
    """

    def decrypt_data(self, num, gen, data):
        """Return data unchanged; num and gen are accepted but ignored."""
        return data
121
122
def decrypt_objects(objects, default_filter, filters):
    """Decrypt list of stream objects.

    The parameter default_filter specifies the default filter to use. The
    filters parameter is a dictionary of alternate filters to use when the
    object specifies an alternate filter locally.

    Each stream object is decrypted in place: its ``stream`` is replaced
    by the decrypted data, a leading /Crypt entry is removed from its
    ``Filter`` and it is marked as decrypted so that a second call skips
    it.  Objects already marked as decrypted are left untouched.

    Raises KeyError when a stream names a crypt filter (or defaults to
    /Identity) that is not present in filters.
    """
    for obj in streamobjects(objects):
        if getattr(obj, 'decrypted', False):
            continue

        # Not called ``filter`` so the builtin is not shadowed.
        crypt_filter = default_filter

        # Check whether a locally defined crypt filter should override the
        # default filter.
        ftype = obj.Filter
        if ftype is not None:
            if not isinstance(ftype, list):
                ftype = [ftype]
            if ftype and ftype[0] == PdfName.Crypt:
                ftype = ftype[1:]
                parms = obj.DecodeParms or obj.DP
                # With several filters the parameters are an array holding
                # one entry per filter; /Crypt is the first filter, so its
                # parameters are the first entry.  The parameters entry
                # itself is left unmodified on the object.
                if isinstance(parms, list):
                    parms = parms[0] if parms else None
                # Missing parameters or a missing /Name select the
                # Identity crypt filter, the default given by the PDF
                # specification.
                name = getattr(parms, 'Name', None)
                if name is None:
                    name = PdfName.Identity
                crypt_filter = filters[name]

        num, gen = obj.indirect
        obj.stream = crypt_filter.decrypt_data(num, gen, obj.stream)
        obj.private.decrypted = True
        # An empty remaining filter list is stored as None.
        obj.Filter = ftype or None
151