#!/usr/local/bin/python3.8
# vim:fileencoding=UTF-8:ts=4:sw=4:sta:et:sts=4:ai


__license__   = 'GPL v3'
__copyright__ = '2011, Kovid Goyal <kovid@kovidgoyal.net>'
__docformat__ = 'restructuredtext en'

import struct, string, zlib, os
from collections import OrderedDict
from io import BytesIO

from calibre.utils.img import save_cover_data_to, scale_image, image_to_data, image_from_data, resize_image, png_data_to_gif_data
from calibre.utils.imghdr import what
from calibre.ebooks import normalize
from polyglot.builtins import as_bytes
from tinycss.color3 import parse_color_string

IMAGE_MAX_SIZE = 10 * 1024 * 1024
RECORD_SIZE = 0x1000  # 4096: uncompressed text record size


class PolyglotDict(dict):

    'A dict that transparently encodes str keys to UTF-8 bytes, so lookups work with either str or bytes keys'

    def __setitem__(self, key, val):
        if isinstance(key, str):
            key = key.encode('utf-8')
        dict.__setitem__(self, key, val)

    def __getitem__(self, key):
        if isinstance(key, str):
            key = key.encode('utf-8')
        return dict.__getitem__(self, key)

    def __contains__(self, key):
        if isinstance(key, str):
            key = key.encode('utf-8')
        return dict.__contains__(self, key)


def decode_string(raw, codec='utf-8', ordt_map=None):
    length, = struct.unpack(b'>B', raw[0:1])
    raw = raw[1:1+length]
    consumed = length+1
    if ordt_map:
        return ''.join(ordt_map[x] for x in bytearray(raw)), consumed
    return raw.decode(codec), consumed


def decode_hex_number(raw, codec='utf-8'):
    '''
    Decode a variable length number stored using hexadecimal encoding. The
    first byte of such a number gives the count of bytes that follow, and the
    bytes that follow are simply the hexadecimal representation of the number.

    :param raw: Raw binary data as a bytestring

    :return: The number and the number of bytes from raw that the number
    occupies.
    '''
    raw, consumed = decode_string(raw, codec=codec)
    return int(raw, 16), consumed



def encode_string(raw):
    ans = bytearray(as_bytes(raw))
    ans.insert(0, len(ans))
    return bytes(ans)
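

# A minimal usage sketch: round-trip a short string through
# encode_string()/decode_string(). The _example_* helper is illustrative only
# (its name is not part of the original module and nothing calls it).
def _example_string_roundtrip():
    raw = encode_string('abc')  # one length byte followed by the UTF-8 bytes
    assert raw == b'\x03abc'
    assert decode_string(raw) == ('abc', 4)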


def encode_number_as_hex(num):
    '''
    Encode num as a variable length hexadecimal number and return the
    bytestring containing it. The first byte of such a number gives the count
    of bytes that follow, and the bytes that follow are simply the hexadecimal
    representation of the number.
    '''
    num = hex(num)[2:].upper().encode('ascii')
    nlen = len(num)
    if nlen % 2 != 0:
        num = b'0'+num
    return encode_string(num)
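

# A minimal usage sketch: round-trip a value through
# encode_number_as_hex()/decode_hex_number(). Illustrative only; the helper
# name is arbitrary and nothing in the module calls it.
def _example_hex_number_roundtrip():
    raw = encode_number_as_hex(4095)
    assert raw == b'\x040FFF'  # length byte 4, then zero padded hex digits
    assert decode_hex_number(raw) == (4095, 5)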


def encint(value, forward=True):
    '''
    Some parts of the Mobipocket format encode data as variable-width integers.
    These integers are represented big-endian with 7 bits per byte in bits 1-7.
    They may be either forward-encoded, in which case only the last byte has bit 8 set,
    or backward-encoded, in which case only the first byte has bit 8 set.
    For example, the number 0x11111 = 0b10001000100010001 would be represented
    forward-encoded as:

        0x04 0x22 0x91 = 0b100 0b100010 0b10010001

    And backward-encoded as:

        0x84 0x22 0x11 = 0b10000100 0b100010 0b10001

    This function encodes the integer ``value`` as a variable width integer and
    returns the bytestring corresponding to it.

    If forward is True the bytes returned are suitable for prepending to the
    output buffer, otherwise they must be appended to the output buffer.
    '''
    if value < 0:
        raise ValueError('Cannot encode negative numbers as vwi')
    # Encode vwi
    byts = bytearray()
    while True:
        b = value & 0b01111111
        value >>= 7  # shift value to the right by 7 bits

        byts.append(b)
        if value == 0:
            break
    byts[0 if forward else -1] |= 0b10000000
    byts.reverse()
    return bytes(byts)


def decint(raw, forward=True):
    '''
    Read a variable width integer from the bytestring or bytearray raw and return the
    integer and the number of bytes read. If forward is True bytes are read
    from the start of raw, otherwise from the end of raw.

    This function is the inverse of encint above, see its docs for more
    details.
    '''
    val = 0
    byts = bytearray()
    src = bytearray(raw)
    if not forward:
        src.reverse()
    for bnum in src:
        byts.append(bnum & 0b01111111)
        if bnum & 0b10000000:
            break
    if not forward:
        byts.reverse()
    for byte in byts:
        val <<= 7  # Shift value to the left by 7 bits
        val |= byte

    return val, len(byts)
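

# An illustrative sketch: encode the 0x11111 example from the encint()
# docstring in both directions and decode it back. The helper name is
# arbitrary and nothing in the module calls it.
def _example_vwi_roundtrip():
    fwd = encint(0x11111, forward=True)
    bwd = encint(0x11111, forward=False)
    assert fwd == b'\x04\x22\x91'  # last byte has bit 8 set
    assert bwd == b'\x84\x22\x11'  # first byte has bit 8 set
    assert decint(fwd, forward=True) == (0x11111, 3)
    assert decint(bwd, forward=False) == (0x11111, 3)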


def test_decint(num):
    for d in (True, False):
        raw = encint(num, forward=d)
        sz = len(raw)
        if (num, sz) != decint(raw, forward=d):
            raise ValueError('Failed for num %d, forward=%r: %r != %r' % (
                num, d, (num, sz), decint(raw, forward=d)))


def rescale_image(data, maxsizeb=IMAGE_MAX_SIZE, dimen=None):
    '''
    Convert the image, setting all transparent pixels to white and changing the
    format to JPEG. Ensure the resultant image has a byte size less than
    maxsizeb.

    If dimen is not None, generate a thumbnail of width=dimen, height=dimen
    (if dimen is a single number) or width, height = dimen (if dimen is a
    sequence of two numbers).

    Returns the image as a bytestring
    '''
    if dimen is not None:
        if hasattr(dimen, '__len__'):
            width, height = dimen
        else:
            width = height = dimen
        data = scale_image(data, width=width, height=height, compression_quality=90)[-1]
    else:
        # Replace transparent pixels with white pixels and convert to JPEG
        data = save_cover_data_to(data)
    if len(data) <= maxsizeb:
        return data
    orig_data = data  # save it in case compression fails
    quality = 90
    while len(data) > maxsizeb and quality >= 5:
        data = image_to_data(image_from_data(orig_data), compression_quality=quality)
        quality -= 5
    if len(data) <= maxsizeb:
        return data
    orig_data = data

    scale = 0.9
    while len(data) > maxsizeb and scale >= 0.05:
        img = image_from_data(data)
        w, h = img.width(), img.height()
        img = resize_image(img, int(scale*w), int(scale*h))
        data = image_to_data(img, compression_quality=quality)
        scale -= 0.05
    return data


def get_trailing_data(record, extra_data_flags):
    '''
    Given a text record as a bytestring and the extra data flags from the MOBI
    header, return the trailing data as a dictionary, mapping bit number to
    data as a bytestring. Also returns the record with all trailing data
    removed.

    :return: Trailing data, record with trailing data removed
    '''
    data = OrderedDict()
    flags = extra_data_flags >> 1

    num = 0
    while flags:
        num += 1
        if flags & 0b1:
            sz, consumed = decint(record, forward=False)
            if sz > consumed:
                data[num] = record[-sz:-consumed]
            record = record[:-sz]
        flags >>= 1
    # Read multibyte chars if any
    if extra_data_flags & 0b1:
        # Only the first two bits are used for the size since there can
        # never be more than 3 trailing multibyte chars
        sz = (ord(record[-1:]) & 0b11) + 1
        consumed = 1
        if sz > consumed:
            data[0] = record[-sz:-consumed]
        record = record[:-sz]
    return data, record


def encode_trailing_data(raw):
    '''
    Given some data in the bytestring raw, return a bytestring of the form

        <data><size>

    where size is a backwards encoded vwi whose value is the length of the
    entire returned bytestring. data is the bytestring passed in as raw.

    This is the encoding used for trailing data entries at the end of text
    records. See get_trailing_data() for details.
    '''
    lsize = 1
    while True:
        encoded = encint(len(raw) + lsize, forward=False)
        if len(encoded) == lsize:
            break
        lsize += 1
    return raw + encoded
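

# A small usage sketch: append a trailing data entry to a record and read it
# back with get_trailing_data(). The flag value 0b10 (bit 1) and the payload
# are illustration data only; the helper name is arbitrary.
def _example_trailing_data_roundtrip():
    record = b'some record text' + encode_trailing_data(b'xyz')
    data, stripped = get_trailing_data(record, 0b10)
    assert data == {1: b'xyz'}
    assert stripped == b'some record text'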


def encode_fvwi(val, flags, flag_size=4):
    '''
    Encode the value val and the flag_size bits from flags as a fvwi. This encoding is
    used in the trailing byte sequences for indexing. Returns encoded
    bytestring.
    '''
    ans = val << flag_size
    for i in range(flag_size):
        ans |= (flags & (1 << i))
    return encint(ans)


def decode_fvwi(byts, flag_size=4):
    '''
    Decode encoded fvwi. Returns number, flags, consumed
    '''
    arg, consumed = decint(bytes(byts))
    val = arg >> flag_size
    flags = 0
    for i in range(flag_size):
        flags |= (arg & (1 << i))
    return val, flags, consumed
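

# An illustrative sketch: round-trip a value plus four flag bits through
# encode_fvwi()/decode_fvwi(). The helper name is arbitrary and nothing in the
# module calls it.
def _example_fvwi_roundtrip():
    raw = encode_fvwi(5, 0b1010)
    assert raw == b'\xda'  # (5 << 4 | 0b1010) encoded as a single byte vwi
    assert decode_fvwi(raw) == (5, 0b1010, 1)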


def decode_tbs(byts, flag_size=4):
    '''
    Trailing byte sequences for indexing consist of a series of fvwi numbers.
    This function reads the fvwi number and its associated flags. It then uses
    the flags to read any more numbers that belong to the series. The flags are
    the lowest 4 bits of the vwi (see the encode_fvwi function above).

    Returns the fvwi number, a dictionary mapping flag bits to the associated
    data and the number of bytes consumed.
    '''
    byts = bytes(byts)
    val, flags, consumed = decode_fvwi(byts, flag_size=flag_size)
    extra = {}
    byts = byts[consumed:]
    if flags & 0b1000 and flag_size > 3:
        extra[0b1000] = True
    if flags & 0b0010:
        x, consumed2 = decint(byts)
        byts = byts[consumed2:]
        extra[0b0010] = x
        consumed += consumed2
    if flags & 0b0100:
        extra[0b0100] = ord(byts[0:1])
        byts = byts[1:]
        consumed += 1
    if flags & 0b0001:
        x, consumed2 = decint(byts)
        byts = byts[consumed2:]
        extra[0b0001] = x
        consumed += consumed2
    return val, extra, consumed


def encode_tbs(val, extra, flag_size=4):
    '''
    Encode the number val and the extra data in the extra dict as an fvwi. See
    decode_tbs above.
    '''
    flags = 0
    for flag in extra:
        flags |= flag
    ans = encode_fvwi(val, flags, flag_size=flag_size)

    if 0b0010 in extra:
        ans += encint(extra[0b0010])
    if 0b0100 in extra:
        ans += bytes(bytearray([extra[0b0100]]))
    if 0b0001 in extra:
        ans += encint(extra[0b0001])
    return ans
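

# An illustrative sketch: encode a trailing byte sequence entry carrying two
# extra values and decode it back. The flag bits and values are illustration
# data only; the helper name is arbitrary.
def _example_tbs_roundtrip():
    raw = encode_tbs(3, {0b0010: 7, 0b0001: 2})
    assert raw == b'\xb3\x87\x82'
    assert decode_tbs(raw) == (3, {0b0010: 7, 0b0001: 2}, 3)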


def utf8_text(text):
    '''
    Convert a possibly null string to utf-8 bytes, guaranteeing to return a
    non-empty, normalized bytestring.
    '''
    if text and text.strip():
        text = text.strip()
        if not isinstance(text, str):
            text = text.decode('utf-8', 'replace')
        text = normalize(text).encode('utf-8')
    else:
        text = _('Unknown').encode('utf-8')
    return text


def align_block(raw, multiple=4, pad=b'\0'):
    '''
    Return raw with enough pad bytes appended to ensure its length is a
    multiple of ``multiple`` (4 by default).
    '''
    extra = len(raw) % multiple
    if extra == 0:
        return raw
    return raw + pad*(multiple - extra)
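

# A minimal usage sketch: pad a 5 byte string out to the next multiple of 4.
# Illustrative only; the helper name is arbitrary.
def _example_align_block():
    assert align_block(b'abcd') == b'abcd'
    assert align_block(b'abcde') == b'abcde\x00\x00\x00'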


def detect_periodical(toc, log=None):
    '''
    Detect if the TOC object toc contains a periodical that conforms to the
    structure required by kindlegen to generate a periodical.
    '''
    if toc.count() < 1 or not toc[0].klass == 'periodical':
        return False
    for node in toc.iterdescendants():
        if node.depth() == 1 and node.klass != 'article':
            if log is not None:
                log.debug(
                'Not a periodical: Deepest node does not have '
                'class="article"')
            return False
        if node.depth() == 2 and node.klass != 'section':
            if log is not None:
                log.debug(
                'Not a periodical: Second deepest node does not have'
                ' class="section"')
            return False
        if node.depth() == 3 and node.klass != 'periodical':
            if log is not None:
                log.debug('Not a periodical: Third deepest node'
                    ' does not have class="periodical"')
            return False
        if node.depth() > 3:
            if log is not None:
                log.debug('Not a periodical: Has nodes of depth > 3')
            return False
    return True


def count_set_bits(num):
    if num < 0:
        num = -num
    ans = 0
    while num > 0:
        ans += (num & 0b1)
        num >>= 1
    return ans


def to_base(num, base=32, min_num_digits=None):
    'Convert num to a string in the specified base, zero padded to at least min_num_digits digits'
    digits = string.digits + string.ascii_uppercase
    sign = 1 if num >= 0 else -1
    if num == 0:
        return ('0' if min_num_digits is None else '0'*min_num_digits)
    num *= sign
    ans = []
    while num:
        ans.append(digits[(num % base)])
        num //= base
    if min_num_digits is not None and len(ans) < min_num_digits:
        ans.extend('0'*(min_num_digits - len(ans)))
    if sign < 0:
        ans.append('-')
    ans.reverse()
    return ''.join(ans)
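

# An illustrative sketch: a few worked values for count_set_bits() and
# to_base(). The helper name is arbitrary and nothing in the module calls it.
def _example_bit_and_base_helpers():
    assert count_set_bits(0b101101) == 4
    assert to_base(255, base=16) == 'FF'
    assert to_base(255, base=32, min_num_digits=4) == '007V'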


def mobify_image(data):
    'Convert PNG images to GIF as the idiotic Kindle cannot display some PNG'
    fmt = what(None, data)
    if fmt == 'png':
        data = png_data_to_gif_data(data)
    return data

# Font records {{{


def read_font_record(data, extent=1040):
    '''
    Return the font encoded in the MOBI FONT record represented by data.
    The return value is a dict with fields raw_data, font_data, err, ext,
    headers.

    :param extent: The number of obfuscated bytes. So far I have only
    encountered files with 1040 obfuscated bytes. If you encounter an
    obfuscated record for which this function fails, try different extent
    values (easily automated).

    raw_data is the raw data in the font record
    font_data is the decoded font_data or None if an error occurred
    err is not None if some error occurred
    ext is the font type (ttf for TrueType, otf for OpenType, dat for unknown
    and failed if an error occurred)
    headers is the dict of decoded header fields from the font record or None
    if decoding failed
    '''
    # Format:
    # bytes  0 -  3:  'FONT'
    # bytes  4 -  7:  Uncompressed size
    # bytes  8 - 11:  flags
    #                   bit 1 - zlib compression
    #                   bit 2 - XOR obfuscated
    # bytes 12 - 15:  offset to start of compressed data
    # bytes 16 - 19:  length of XOR string
    # bytes 20 - 23:  offset to start of XOR data
    # The zlib compressed data begins with 2 bytes of header and
    # has 4 bytes of checksum at the end
    ans = {'raw_data':data, 'font_data':None, 'err':None, 'ext':'failed',
            'headers':None, 'encrypted':False}

    try:
        usize, flags, dstart, xor_len, xor_start = struct.unpack_from(
                b'>LLLLL', data, 4)
    except:
        ans['err'] = 'Failed to read font record header fields'
        return ans
    font_data = data[dstart:]
    ans['headers'] = {'usize':usize, 'flags':bin(flags), 'xor_len':xor_len,
            'xor_start':xor_start, 'dstart':dstart}

    if flags & 0b10:
        # De-obfuscate the data
        key = bytearray(data[xor_start:xor_start+xor_len])
        buf = bytearray(font_data)
        extent = len(font_data) if extent is None else extent
        extent = min(extent, len(font_data))

        for n in range(extent):
            buf[n] ^= key[n%xor_len]  # XOR of buf and key

        font_data = bytes(buf)
        ans['encrypted'] = True

    if flags & 0b1:
        # ZLIB compressed data
        try:
            font_data = zlib.decompress(font_data)
        except Exception as e:
            ans['err'] = 'Failed to zlib decompress font data (%s)'%e
            return ans

        if len(font_data) != usize:
            ans['err'] = 'Uncompressed font size mismatch'
            return ans

    ans['font_data'] = font_data
    sig = font_data[:4]
    ans['ext'] = ('ttf' if sig in {b'\0\1\0\0', b'true', b'ttcf'}
                    else 'otf' if sig == b'OTTO' else 'dat')

    return ans


def write_font_record(data, obfuscate=True, compress=True):
    '''
    Write the ttf/otf font represented by data into a font record. See
    read_font_record() for details on the format of the record.
    '''

    flags = 0
    key_len = 20
    usize = len(data)
    xor_key = b''
    if compress:
        flags |= 0b1
        data = zlib.compress(data, 9)
    if obfuscate and len(data) >= 1040:
        flags |= 0b10
        xor_key = os.urandom(key_len)
        key = bytearray(xor_key)
        data = bytearray(data)
        for i in range(1040):
            data[i] ^= key[i%key_len]
        data = bytes(data)

    key_start = struct.calcsize(b'>5L') + 4
    data_start = key_start + len(xor_key)

    header = b'FONT' + struct.pack(b'>5L', usize, flags, data_start,
            len(xor_key), key_start)

    return header + xor_key + data
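

# A small round-trip sketch: pack placeholder bytes into a font record and read
# them back. Real callers would pass actual TTF/OTF data; this payload
# compresses to well under 1040 bytes, so only the zlib path is exercised. The
# helper name is arbitrary and nothing in the module calls it.
def _example_font_record_roundtrip():
    payload = b'0123456789' * 100
    record = write_font_record(payload, obfuscate=True, compress=True)
    ans = read_font_record(record)
    assert ans['err'] is None
    assert ans['font_data'] == payload
    assert ans['ext'] == 'dat'  # the placeholder payload has no font signature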

# }}}


def create_text_record(text):
    '''
    Return a Palmdoc record of size RECORD_SIZE from the text file object.
    In case the record ends in the middle of a multibyte character return
    the overlap as well.

    Returns data, overlap: where both are byte strings. overlap is the
    extra bytes needed to complete the truncated multibyte character.
    '''
    opos = text.tell()
    text.seek(0, 2)
    # npos is the position of the next record
    npos = min((opos + RECORD_SIZE, text.tell()))
    # Number of bytes from the next record needed to complete the last
    # character in this record
    extra = 0

    last = b''
    while not last.decode('utf-8', 'ignore'):
        # last contains no valid utf-8 characters
        size = len(last) + 1
        text.seek(npos - size)
        last = text.read(size)

    # last now has one valid utf-8 char and possibly some bytes that belong
    # to a truncated char

    try:
        last.decode('utf-8', 'strict')
    except UnicodeDecodeError:
        # There are some truncated bytes in last
        prev = len(last)
        while True:
            text.seek(npos - prev)
            last = text.read(len(last) + 1)
            try:
                last.decode('utf-8')
            except UnicodeDecodeError:
                pass
            else:
                break
        extra = len(last) - prev

    text.seek(opos)
    data = text.read(RECORD_SIZE)
    overlap = text.read(extra)
    text.seek(npos)

    return data, overlap
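

# An illustrative sketch: build a text stream whose 4096 byte boundary falls
# inside a multibyte UTF-8 character and check that the returned overlap
# completes that character. Illustrative only; the helper name is arbitrary.
def _example_create_text_record():
    buf = BytesIO(b'a' + ('\u00e9' * 3000).encode('utf-8'))
    data, overlap = create_text_record(buf)
    assert len(data) == RECORD_SIZE
    assert overlap == b'\xa9'         # second byte of the split U+00E9
    (data + overlap).decode('utf-8')  # decodes cleanly once the overlap is added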


class CNCX:  # {{{

    '''
    Create the CNCX records. These are records containing all the strings from
    an index. Each string is stored as: <vwi string size><utf-8 encoded
    string>
    '''

    MAX_STRING_LENGTH = 500

    def __init__(self, strings=()):
        self.strings = OrderedDict((s, 0) for s in strings)

        self.records = []
        offset = 0
        buf = BytesIO()
        RECORD_LIMIT = 0x10000 - 1024  # kindlegen appears to use 1024, PDB limit is 0x10000
        for key in self.strings:
            utf8 = utf8_text(key[:self.MAX_STRING_LENGTH])
            l = len(utf8)
            sz_bytes = encint(l)
            raw = sz_bytes + utf8
            if buf.tell() + len(raw) > RECORD_LIMIT:
                self.records.append(align_block(buf.getvalue()))
                buf.seek(0), buf.truncate(0)
                offset = len(self.records) * 0x10000
            buf.write(raw)
            self.strings[key] = offset
            offset += len(raw)

        val = buf.getvalue()
        if val:
            self.records.append(align_block(val))

    def __getitem__(self, string):
        return self.strings[string]

    def __bool__(self):
        return bool(self.records)
    __nonzero__ = __bool__

    def __len__(self):
        return len(self.records)

# }}}
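

# A minimal usage sketch: build a CNCX from two strings and look up their
# offsets. Each entry costs one vwi length byte plus the UTF-8 bytes, so 'one'
# sits at offset 0 and 'two' at offset 4. Illustrative only; the helper name
# is arbitrary.
def _example_cncx():
    cncx = CNCX(['one', 'two'])
    assert len(cncx) == 1  # everything fits in a single record
    assert cncx['one'] == 0
    assert cncx['two'] == 4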


def is_guide_ref_start(ref):
    return (ref.title.lower() == 'start' or
            (ref.type and ref.type.lower() in {'start',
                    'other.start', 'text'}))


def convert_color_for_font_tag(val):
    rgba = parse_color_string(str(val or ''))
    if rgba is None or rgba == 'currentColor':
        return str(val)
    clamp = lambda x: max(0, min(x, 1))  # clamp each channel into the [0, 1] range
    rgb = map(clamp, rgba[:3])
    return '#' + ''.join(map(lambda x:'%02x' % int(x * 255), rgb))
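

# A small usage sketch: named CSS colors are converted to the #rrggbb form used
# by <font> tags, while the currentColor keyword is passed through unchanged.
# Illustrative only; the helper name is arbitrary.
def _example_convert_color_for_font_tag():
    assert convert_color_for_font_tag('red') == '#ff0000'
    assert convert_color_for_font_tag('currentColor') == 'currentColor'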