1#!/usr/local/bin/python3.8
2# vim:fileencoding=utf-8
3
4
5__license__ = 'GPL v3'
6__copyright__ = '2015, Kovid Goyal <kovid at kovidgoyal.net>'
7
8import errno, socket, os, time
9from email.utils import formatdate
10from operator import itemgetter
11
12from calibre import prints
13from calibre.constants import iswindows
14from calibre.srv.errors import HTTPNotFound
15from calibre.utils.localization import get_translator
16from calibre.utils.socket_inheritance import set_socket_inherit
17from calibre.utils.logging import ThreadSafeLog
18from calibre.utils.shared_file import share_open
19from polyglot.builtins import iteritems
20from polyglot import reprlib
21from polyglot.http_cookie import SimpleCookie
22from polyglot.builtins import as_unicode
23from polyglot.urllib import parse_qs, quote as urlquote
24from polyglot.binary import as_hex_unicode as encode_name, from_hex_unicode as decode_name
25
# HTTP protocol version strings.
HTTP1 = 'HTTP/1.0'
HTTP11 = 'HTTP/1.1'
# Preferred socket send-buffer size in bytes (16 KiB); windows 7 uses an 8KB sndbuf.
DESIRED_SEND_BUFFER_SIZE = 16 << 10
# NOTE(review): this bare expression has no runtime effect. It presumably
# exists so linters treat the encode_name/decode_name imports as used, i.e.
# they are re-exported from this module for other code -- confirm before removing.
encode_name, decode_name
30
31
def http_date(timeval=None):
    '''
    Format ``timeval`` (seconds since the epoch; ``None`` means now) as an
    HTTP date string in GMT, e.g. ``Thu, 01 Jan 1970 00:00:00 GMT``.
    '''
    stamp = formatdate(timeval=timeval, usegmt=True)
    return str(stamp)
34
35
class MultiDict(dict):  # {{{

    '''
    A dict that keeps every value assigned to a key.

    Each key is stored against a list of values. Ordinary access
    (``d[key]``, ``get``, ``pop``) yields the most recently added value;
    the ``all``/``duplicates`` flags on the accessors expose the full list.
    '''

    def __setitem__(self, key, val):
        # Append to the key's existing list (creating it on first use)
        # instead of replacing the previous value.
        bucket = dict.get(self, key, [])
        bucket.append(val)
        dict.__setitem__(self, key, bucket)

    def __getitem__(self, key):
        # Latest value wins; unknown keys raise KeyError.
        bucket = dict.__getitem__(self, key)
        return bucket[-1]

    @staticmethod
    def create_from_query_string(qs):
        '''
        Build a MultiDict from a URL query string.

        Blank values are kept; keys and values are passed through as_unicode.
        '''
        result = MultiDict()
        parsed = parse_qs(as_unicode(qs), keep_blank_values=True)
        for name, raw_values in iteritems(parsed):
            dict.__setitem__(result, as_unicode(name), list(map(as_unicode, raw_values)))
        return result

    def update_from_listdict(self, ld):
        '''Append every value of ``ld`` (a mapping of key -> iterable of values) under its key.'''
        for key, bucket in iteritems(ld):
            for item in bucket:
                self[key] = item

    def items(self, duplicates=True):
        '''
        Yield ``(key, value)`` pairs.

        With duplicates=True (the default) every stored value is yielded;
        otherwise only the latest value for each key.
        '''
        for key, bucket in dict.items(self):
            if not duplicates:
                yield key, bucket[-1]
                continue
            for item in bucket:
                yield key, item
    iteritems = items

    def values(self, duplicates=True):
        '''Yield stored values: all of them, or only the latest per key when duplicates is False.'''
        for bucket in dict.values(self):
            if not duplicates:
                yield bucket[-1]
                continue
            yield from bucket
    itervalues = values

    def set(self, key, val, replace_all=False):
        '''Add ``val`` under ``key``; with replace_all=True it becomes the only value for ``key``.'''
        if not replace_all:
            self[key] = val
            return
        dict.__setitem__(self, key, [val])

    def get(self, key, default=None, all=False):
        '''
        Return the latest value for ``key``, or ``default`` if it is missing.

        With all=True return the list of every value for ``key`` (the stored
        list itself, not a copy), or an empty list if it is missing.
        '''
        if not all:
            try:
                return self.__getitem__(key)
            except KeyError:
                return default
        try:
            return dict.__getitem__(self, key)
        except KeyError:
            return []

    def pop(self, key, default=None, all=False):
        '''
        Remove ``key`` and return its latest value, or with all=True the list
        of all its values. A missing key gives ``default`` (an empty list
        when all=True).
        '''
        removed = dict.pop(self, key, default)
        if removed is default:
            return [] if all else default
        if all:
            return removed
        return removed[-1]

    def __repr__(self):
        pairs = ('%s: %s' % (reprlib.repr(k), reprlib.repr(v)) for k, v in iteritems(self))
        return '{%s}' % ', '.join(pairs)
    __str__ = __unicode__ = __repr__

    def pretty(self, leading_whitespace=''):
        '''
        Return one ``key: value`` line per stored value, sorted by key and
        prefixed with ``leading_whitespace``; bytes values are shown via repr().
        '''
        lines = []
        for k, v in sorted(self.items(), key=itemgetter(0)):
            shown = repr(v) if isinstance(v, bytes) else v
            lines.append('%s: %s' % (k, shown))
        return leading_whitespace + ('\n' + leading_whitespace).join(lines)
# }}}
109
110
def error_codes(*errnames):
    '''
    Map errno names (e.g. ``'EINTR'``) to their numeric codes.

    Names the errno module does not define on this platform are skipped,
    so the result is a set containing only the codes that exist.
    '''
    looked_up = (getattr(errno, name, None) for name in errnames)
    return {code for code in looked_up if code is not None}
116
117
socket_errors_eintr = error_codes('EINTR', 'WSAEINTR')

# Errors indicating a disconnected connection. POSIX names are listed first and
# Winsock names second; names missing on the current platform are dropped by
# error_codes().
socket_errors_socket_closed = error_codes(
    # POSIX
    'EPIPE', 'EBADF', 'ENOTSOCK', 'ENOTCONN', 'ESHUTDOWN', 'ETIMEDOUT',
    'ECONNREFUSED', 'ECONNRESET', 'ECONNABORTED', 'ENETRESET',
    'EHOSTDOWN', 'EHOSTUNREACH',
    # Winsock
    'WSAEBADF', 'WSAENOTSOCK', 'WSAENOTCONN', 'WSAESHUTDOWN', 'WSAETIMEDOUT',
    'WSAECONNREFUSED', 'WSAECONNRESET', 'WSAECONNABORTED', 'WSAENETRESET',
)
# Errors raised by non-blocking socket operations that would have blocked.
socket_errors_nonblocking = error_codes('EAGAIN', 'EWOULDBLOCK', 'WSAEWOULDBLOCK')
135
136
def start_cork(sock):
    '''
    Turn TCP_CORK on for ``sock``. On Linux this holds back partial frames
    until the socket is uncorked (see tcp(7)). Does nothing on platforms
    whose socket module lacks TCP_CORK.
    '''
    cork = getattr(socket, 'TCP_CORK', None)
    if cork is None:
        return
    sock.setsockopt(socket.IPPROTO_TCP, cork, 1)
140
141
def stop_cork(sock):
    '''
    Turn TCP_CORK off for ``sock``, releasing any held-back data. Does
    nothing on platforms whose socket module lacks TCP_CORK.
    '''
    cork = getattr(socket, 'TCP_CORK', None)
    if cork is None:
        return
    sock.setsockopt(socket.IPPROTO_TCP, cork, 0)
145
146
def create_sock_pair():
    '''
    Return a connected ``(client_sock, srv_sock)`` pair from
    socket.socketpair(), with inheritance disabled on both via
    set_socket_inherit(..., False).
    '''
    client_sock, srv_sock = socket.socketpair()
    for s in (client_sock, srv_sock):
        set_socket_inherit(s, False)
    return client_sock, srv_sock
152
153
def parse_http_list(header_val):
    """Parse lists as described by RFC 2068 Section 2.

    In particular, parse comma-separated lists where the elements of
    the list may include quoted-strings.  A quoted-string could
    contain a comma.  A non-quoted string could have quotes in the
    middle.  Neither commas nor quotes count if they are escaped.
    Only double-quotes count, not single-quotes.

    This is a generator that yields each element with surrounding
    whitespace stripped. ``header_val`` may be ``str`` or ``bytes``, and the
    yielded elements have the same type as the input. Inside a quoted-string
    a backslash escape is resolved (the backslash is dropped), while the
    surrounding double quotes are kept.
    """
    if isinstance(header_val, bytes):
        slash, dquote, comma = b'\\', b'"', b','
        empty = b''
        # Iterating over bytes yields ints, which can neither be compared with
        # the delimiters above nor concatenated onto bytes, so walk the input
        # as one-byte slices instead.
        chars = (header_val[i:i + 1] for i in range(len(header_val)))
    else:
        slash, dquote, comma = '\\', '"', ','
        empty = ''
        chars = header_val

    part = []  # pieces of the element currently being built
    escape = quote = False
    for cur in chars:
        if escape:
            part.append(cur)
            escape = False
            continue
        if quote:
            if cur == slash:
                escape = True
                continue
            if cur == dquote:
                quote = False
            part.append(cur)
            continue

        if cur == comma:
            yield empty.join(part).strip()
            part = []
            continue

        if cur == dquote:
            quote = True

        part.append(cur)

    if part:
        yield empty.join(part).strip()


def parse_http_dict(header_val):
    '''
    Parse an HTTP comma separated header with items of the form a=1, b="xxx" into a dictionary.

    Accepts ``str`` or ``bytes``; keys and values keep the input type.
    Optional whitespace around ``=`` is ignored (RFC 7235 BWS). Surrounding
    double quotes are removed from values. Items with an empty key are
    skipped, and items without ``=`` map to an empty value. Empty or ``None``
    input gives ``{}``.
    '''
    if not header_val:
        return {}
    if isinstance(header_val, bytes):
        # Unpacking a bytes literal would give ints, so spell each out.
        sep, dquote = b'=', b'"'
    else:
        sep, dquote = '=', '"'
    ans = {}
    for item in parse_http_list(header_val):
        k, _, v = item.partition(sep)
        k, v = k.strip(), v.strip()
        if k:
            if v.startswith(dquote) and v.endswith(dquote):
                v = v[1:-1]
            ans[k] = v
    return ans


def sort_q_values(header_val):
    '''
    Get sorted items from an HTTP header of type: a;q=0.5, b;q=0.7...

    Returns a tuple of the items, with their parameters removed, ordered by
    descending q value. q values are clamped to [0, 1]. Items without a
    usable q count as 1.0, and ties keep their order in the header. Empty or
    ``None`` input gives an empty tuple.
    '''
    if not header_val:
        return ()

    def item(x):
        e, _, params = x.partition(';')
        q = 1.0
        # RFC 7231 allows whitespace around ';' and other parameters before
        # the weight (e.g. "text/html;level=1;q=0.5"). The "q" name is
        # case-insensitive, so scan every parameter rather than just the first.
        for param in params.split(';'):
            p, _, v = param.partition('=')
            if p.strip().lower() == 'q':
                v = v.strip()
                if v:
                    try:
                        q = max(0.0, min(1.0, float(v)))
                    except ValueError:
                        pass  # malformed weight: keep the default of 1.0
                break
        return e.strip(), q
    return tuple(map(itemgetter(0), sorted(map(item, parse_http_list(header_val)), key=itemgetter(1), reverse=True)))
231
232
def eintr_retry_call(func, *args, **kwargs):
    '''
    Return ``func(*args, **kwargs)``, calling it again whenever it raises an
    OSError whose errno is in socket_errors_eintr. Any other exception
    propagates unchanged.

    NOTE: since Python 3.5 (PEP 475) most system calls already retry on EINTR
    by themselves.
    '''
    while True:
        try:
            result = func(*args, **kwargs)
        except OSError as err:
            if getattr(err, 'errno', None) not in socket_errors_eintr:
                raise
        else:
            return result
241
242
def get_translator_for_lang(cache, bcp_47_code):
    '''
    Return the translator for ``bcp_47_code``. Results are memoized in the
    caller-supplied ``cache`` mapping, and get_translator() is only called
    on a miss.
    '''
    try:
        hit = cache[bcp_47_code]
    except KeyError:
        pass
    else:
        return hit
    translator = get_translator(bcp_47_code)
    cache[bcp_47_code] = translator
    return translator
250
251
def encode_path(*components):
    '''
    Build an absolute URL path from ``components``. Each component is
    UTF-8 encoded and URL-quoted with no safe characters (so presumably any
    '/' inside a component is escaped too), and the results are joined with '/'.
    '''
    quoted = [urlquote(component.encode('utf-8'), '') for component in components]
    return '/' + '/'.join(quoted)
255
256
class Cookie(SimpleCookie):

    '''
    SimpleCookie subclass overriding the name-mangled private
    ``_BaseCookie__set`` hook.

    NOTE(review): the override only forwards to the inherited
    implementation, so it currently has no effect; presumably it is kept for
    historical reasons.
    '''

    def _BaseCookie__set(self, key, real_value, coded_value):
        return super()._BaseCookie__set(key, real_value, coded_value)
261
262
def custom_fields_to_display(db):
    '''Return the keys from ``db.field_metadata.ignorable_field_keys()`` as a frozenset.'''
    metadata = db.field_metadata
    keys = metadata.ignorable_field_keys()
    return frozenset(keys)
265
266# Logging {{{
267
268
class ServerLog(ThreadSafeLog):
    '''ThreadSafeLog used by the server; it only overrides exception_traceback_level.'''
    # NOTE(review): presumably the level at which exception tracebacks are
    # logged; confirm against ThreadSafeLog's use of this attribute.
    exception_traceback_level = ThreadSafeLog.WARN
271
272
class RotatingStream:

    '''
    Log output that appends to a file and rotates it by size.

    When ``max_size`` is set and, after a write, the stream position (as
    reported by tell()) exceeds it, the current file is renamed to
    ``<filename>.1``. Existing ``<filename>.N`` copies are shifted to
    ``.N+1`` for N below ``history``, and a fresh file is opened.
    '''

    def __init__(self, filename, max_size=None, history=5):
        self.filename, self.history, self.max_size = filename, history, max_size
        if iswindows:
            # Use an absolute path with the Windows extended-length prefix.
            self.filename = '\\\\?\\' + os.path.abspath(self.filename)
        self.set_output()

    def set_output(self):
        '''(Re)open self.filename for appending and store it in self.stream.'''
        if iswindows:
            self.stream = share_open(self.filename, 'a', newline='')
        else:
            # see https://bugs.python.org/issue27805 -- open() in mode 'a'
            # fails for non-seekable files such as /dev/stdout, so request
            # O_APPEND on the raw descriptor and wrap it in mode 'w'.
            self.stream = open(os.open(self.filename, os.O_WRONLY|os.O_APPEND|os.O_CREAT|os.O_CLOEXEC), 'w')
        try:
            self.stream.tell()
        except OSError:
            # Happens if filename is /dev/stdout for example. Size-based
            # rotation needs tell(), so disable it.
            self.max_size = None

    def flush(self):
        self.stream.flush()

    def prints(self, level, *args, **kwargs):
        '''Write ``args`` to the log via calibre's prints(), then rotate if needed. ``level`` is unused.'''
        kwargs['file'] = self.stream
        prints(*args, **kwargs)
        self.rollover()

    def rename(self, src, dest):
        '''Rename ``src`` to ``dest``, ignoring the case where ``src`` does not exist.'''
        try:
            if iswindows:
                from calibre_extensions import winutil
                winutil.move_file(src, dest)
            else:
                os.rename(src, dest)
        except OSError as e:
            if e.errno != errno.ENOENT:  # ENOENT: the source of the rename does not exist
                raise

    def rollover(self):
        '''Rotate the log files if max_size is set and has been exceeded.'''
        if not self.max_size or self.stream.tell() <= self.max_size:
            return
        self.stream.close()
        # Shift existing copies up by one, highest number first, so each file
        # is moved out of the way before the next one is renamed onto it.
        for i in range(self.history - 1, 0, -1):
            src, dest = '%s.%d' % (self.filename, i), '%s.%d' % (self.filename, i+1)
            self.rename(src, dest)
        self.rename(self.filename, '%s.%d' % (self.filename, 1))
        self.set_output()

    def clear(self):
        '''
        Delete the log file and every ``<filename>.*`` file, then reopen the log.

        :return: dict mapping each path that could not be removed to the
            OSError raised for it. The dict is empty when everything was
            removed, and also when the log is /dev/stdout or /dev/stderr,
            which are never deleted.
        '''
        if self.filename in ('/dev/stdout', '/dev/stderr'):
            return {}
        self.stream.close()
        failed = {}
        try:
            os.remove(self.filename)
        except OSError as e:
            failed[self.filename] = e
        import glob
        # Escape the name so glob metacharacters in the log path (e.g. '[')
        # are matched literally; only the '.*' suffix is a wildcard.
        for f in glob.glob(glob.escape(self.filename) + '.*'):
            try:
                os.remove(f)
            except OSError as e:
                failed[f] = e
        self.set_output()
        return failed
339
340
class RotatingLog(ServerLog):

    '''ServerLog whose only output is a size-rotated file (see RotatingStream).'''

    def __init__(self, filename, max_size=None, history=5):
        super().__init__()
        # Replace whatever outputs the base initializer set with the rotating file.
        self.outputs = [RotatingStream(filename, max_size=max_size, history=history)]

    def flush(self):
        '''Flush every configured output.'''
        for stream in self.outputs:
            stream.flush()
350# }}}
351
352
class HandleInterrupt:  # {{{

    '''
    Context manager that runs ``action`` when Ctrl-C is pressed in a Windows
    console. On other platforms it does nothing.
    '''

    # On windows socket functions like accept(), recv(), send() are not
    # interrupted by a Ctrl-C in the console. So to make Ctrl-C work we have to
    # use this special context manager. See the echo server example at the
    # bottom of srv/loop.py for how to use it.

    def __init__(self, action):
        '''
        :param action: zero-argument callable, invoked at most once, on the
            first Ctrl-C received while the handler is registered. Not stored
            at all on non-Windows platforms.
        '''
        if not iswindows:
            return  # Interrupts work fine on POSIX
        self.action = action
        from ctypes import WINFUNCTYPE, windll
        from ctypes.wintypes import BOOL, DWORD

        kernel32 = windll.LoadLibrary('kernel32')

        # <http://msdn.microsoft.com/en-us/library/ms686016.aspx>
        # Handler signature: BOOL WINAPI HandlerRoutine(DWORD dwCtrlType)
        PHANDLER_ROUTINE = WINFUNCTYPE(BOOL, DWORD)
        self.SetConsoleCtrlHandler = kernel32.SetConsoleCtrlHandler
        self.SetConsoleCtrlHandler.argtypes = (PHANDLER_ROUTINE, BOOL)
        self.SetConsoleCtrlHandler.restype = BOOL

        @PHANDLER_ROUTINE
        def handle(event):
            # Returning 1 (TRUE) marks the event as handled; returning 0
            # passes it on to the next handler in the console's chain.
            if event == 0:  # CTRL_C_EVENT
                if self.action is not None:
                    self.action()
                    # Clear the action so it runs only once; any later
                    # Ctrl-C falls through to the next handler.
                    self.action = None
                    return 1
            return 0
        # Keep a reference to the ctypes callback object: it must stay alive
        # for as long as Windows may call it.
        self.handle = handle

    def __enter__(self):
        # Register the handler (Windows only); a zero return signals failure.
        if iswindows:
            if self.SetConsoleCtrlHandler(self.handle, 1) == 0:
                import ctypes
                raise ctypes.WinError()

    def __exit__(self, *args):
        # Unregister the handler (Windows only); exceptions from the body
        # are not suppressed.
        if iswindows:
            if self.SetConsoleCtrlHandler(self.handle, 0) == 0:
                import ctypes
                raise ctypes.WinError()
# }}}
397
398
class Accumulator:  # {{{

    '''
    Optimized replacement for BytesIO when the usage pattern is many writes
    followed by a single getvalue(). Chunks are kept in a list and joined
    only once.
    '''

    def __init__(self):
        self._reset()

    def _reset(self):
        # Pending chunks, and the sum of their lengths.
        self._buf = []
        self.total_length = 0

    def append(self, b):
        '''Queue the bytes-like chunk ``b`` and add its length to total_length.'''
        pending = self._buf
        pending.append(b)
        self.total_length += len(b)

    def getvalue(self):
        '''Return every queued chunk joined into one bytes object, then empty the accumulator.'''
        joined = b''.join(self._buf)
        self._reset()
        return joined
# }}}
417
418
def get_db(ctx, rd, library_id):
    '''
    Return the library database for ``library_id``.

    :raises HTTPNotFound: if ctx.get_library() returns None.
    '''
    db = ctx.get_library(rd, library_id)
    if db is not None:
        return db
    raise HTTPNotFound('Library %r not found' % library_id)
424
425
def get_library_data(ctx, rd, strict_library_id=False):
    '''
    Resolve the library named by the request's ``library_id`` query parameter.

    A missing or unknown id falls back to the default library, unless
    ``strict_library_id`` is set and a non-empty unknown id was given; in
    that case HTTPNotFound is raised.

    :return: ``(db, library_id, library_map, default_library)``
    '''
    requested_id = rd.query.get('library_id')
    library_map, default_library = ctx.library_info(rd)
    known = requested_id in library_map
    if not known and strict_library_id and requested_id:
        raise HTTPNotFound('No library with id: {}'.format(requested_id))
    library_id = requested_id if known else default_library
    return get_db(ctx, rd, library_id), library_id, library_map, default_library
435
436
class Offsets:
    '''
    Calculate offsets for a paginated view.

    :param offset: index of the first item on the requested page; negative
        values are clamped to 0.
    :param delta: number of items per page.
    :param total: total number of items.
    :raises HTTPNotFound: if the clamped offset is not less than total.
        This includes offset 0 when total is 0.

    After construction the instance exposes:

    * ``offset`` -- the clamped start offset
    * ``slice_upper_bound`` -- exclusive end index of the current page
    * ``next_offset`` -- start of the following page, or -1 if there is none
    * ``previous_offset`` -- start of the preceding page (never below 0)
    * ``last_offset`` -- start of the final page, i.e. the page that ends
      on the last item (never below 0)
    '''

    def __init__(self, offset, delta, total):
        if offset < 0:
            offset = 0
        if offset >= total:
            raise HTTPNotFound('Invalid offset: %r'%offset)
        last_allowed_index = total - 1
        self.offset = offset
        self.slice_upper_bound = offset + delta
        # The next page starts right after this one, if any items remain.
        self.next_offset = offset + delta
        if self.next_offset > last_allowed_index:
            self.next_offset = -1
        self.previous_offset = max(0, offset - delta)
        # The final page must end on the last item (index total - 1), so it
        # starts at total - delta. Using last_allowed_index - delta here would
        # start one item early and leave the final item off the "last" page.
        self.last_offset = max(0, total - delta)
458
459
# Cached value of the GUI config setting; None means "not read yet".
_use_roman = None


def get_use_roman():
    '''
    Return the ``use_roman_numerals_for_series_number`` setting from
    calibre.gui2's config. It is read on the first call and cached in the
    module global _use_roman.
    '''
    global _use_roman
    if _use_roman is not None:
        return _use_roman
    from calibre.gui2 import config
    _use_roman = config['use_roman_numerals_for_series_number']
    return _use_roman
469
470
def fast_now_strftime(fmt):
    '''Format the current local time with time.strftime(fmt) and pass it through as_unicode (errors='replace').'''
    formatted = time.strftime(fmt)
    return as_unicode(formatted, errors='replace')
473