2# vim:fileencoding=utf-8
3# License: GPLv3 Copyright: 2015, Kovid Goyal <kovid at kovidgoyal.net>
6import apsw
7import json
8import os
9import re
10from functools import lru_cache
11from threading import RLock
13from calibre import as_unicode
14from calibre.constants import config_dir
15from calibre.utils.config import from_json, to_json
16from polyglot.builtins import iteritems
def as_json(data):
    '''
    Serialize *data* to a JSON string.

    Non-ASCII characters are emitted as-is (not \\u-escaped) and objects the
    json module cannot encode natively are handed to calibre's ``to_json``.
    '''
    serialized = json.dumps(data, default=to_json, ensure_ascii=False)
    return serialized
def load_json(raw):
    '''
    Decode the JSON text *raw*, reviving special objects via calibre's
    ``from_json`` hook.

    :return: The decoded value (not necessarily a dict), or {} if decoding
        fails for any reason.
    '''
    try:
        decoded = json.loads(raw, object_hook=from_json)
    except Exception:
        # Corrupt or missing stored data degrades to "empty" instead of failing.
        decoded = {}
    return decoded
def parse_restriction(raw):
    '''
    Decode a serialized restriction (the JSON stored in the ``restriction``
    column) into a normalized dict.

    Malformed input is tolerated rather than raising: a non-dict payload, a
    non-dict ``library_restrictions`` value and a missing or null name list
    all degrade to empty values.

    :param raw: JSON text.
    :return: The decoded dict, in which ``allowed_library_names`` and
        ``blocked_library_names`` are frozensets of lower-cased names and
        ``library_restrictions`` maps lower-cased library names to a
        restriction string ('' when unset). Other keys present in the
        payload are kept unchanged.
    '''
    r = load_json(raw)
    if not isinstance(r, dict):
        r = {}
    lr = r.get('library_restrictions')
    if not isinstance(lr, dict):
        lr = {}

    def lowered_names(key):
        # ``or ()`` also covers an explicit JSON null, which previously made
        # map() raise TypeError.
        return frozenset(name.lower() for name in (r.get(key) or ()))

    r['allowed_library_names'] = lowered_names('allowed_library_names')
    r['blocked_library_names'] = lowered_names('blocked_library_names')
    r['library_restrictions'] = {k.lower(): v or '' for k, v in lr.items()}
    return r
def serialize_restriction(r):
    '''
    Serialize a restriction dict to the JSON string stored in the database.

    ``allowed_library_names`` and ``blocked_library_names`` are written only
    when non-empty. ``library_restrictions`` is always written, with its keys
    lower-cased and falsy values replaced by ''. A missing or null
    ``library_restrictions`` is written as {} (the original crashed on null).

    :param r: Restriction dict, e.g. as produced by ``parse_restriction``.
    :return: JSON text.
    '''
    ans = {}
    for key in ('allowed_library_names', 'blocked_library_names'):
        names = r.get(key)
        if names:
            ans[key] = list(names)
    library_restrictions = r.get('library_restrictions') or {}
    ans['library_restrictions'] = {
        name.lower(): restriction or '' for name, restriction in library_restrictions.items()}
    return json.dumps(ans)
def validate_username(username):
    '''
    Check *username* for characters outside the recommended set.

    :return: A translated warning if *username* contains anything other than
        ASCII letters, digits, spaces, underscores or hyphens, else None.
    '''
    # The character class uses literal ASCII ranges, so e.g. accented letters fail.
    if re.fullmatch(r'[-a-zA-Z_0-9 ]*', username) is None:
        return _('For maximum compatibility you should use only the letters A-Z,'
                 ' the numbers 0-9, spaces, underscores and hyphens in the username')
    return None
def validate_password(pw):
    '''
    Check that *pw* is acceptable as a server password.

    :return: A translated error message when *pw* is empty or contains
        non-ASCII characters, otherwise None.
    '''
    if pw:
        try:
            pw.encode('ascii', 'strict')
        except ValueError:  # UnicodeEncodeError is a subclass of ValueError
            return _('The password must contain only ASCII (English) characters and symbols')
        return None
    return _('Empty passwords are not allowed')
def create_user_data(pw, readonly=False, restriction=None):
    '''
    Build the in-memory record describing one user.

    :param pw: The user's password.
    :param readonly: Whether the user has read-only access.
    :param restriction: Serialized restriction JSON; None/empty means no
        restriction.
    :return: dict with keys ``pw``, ``restriction`` (parsed) and ``readonly``.
    '''
    parsed = parse_restriction(restriction or '{}')
    record = {'pw': pw}
    record['restriction'] = parsed.copy()
    record['readonly'] = readonly
    return record
def connect(path, exc_class=ValueError):
    '''
    Open the SQLite database at *path* with apsw. If opening fails because
    the parent directory does not exist, create that directory and retry.

    :param path: Filesystem path of the database file.
    :param exc_class: Exception class raised, with a descriptive message,
        when the database cannot be opened or its directory cannot be created.
    :return: An ``apsw.Connection``.
    '''
    try:
        return apsw.Connection(path)
    except apsw.CantOpenError as e:
        pdir = os.path.dirname(path)
        # An empty pdir means the file lives in the current directory, so a
        # missing directory cannot be the cause. Previously makedirs('') was
        # attempted and produced a misleading "Failed to make directory" error.
        if not pdir or os.path.isdir(pdir):
            raise exc_class('Failed to open userdb database at {} with error: {}'.format(path, as_unicode(e))) from e
    try:
        # exist_ok: another process may have created the directory meanwhile
        os.makedirs(pdir, exist_ok=True)
    except OSError as e:
        raise exc_class('Failed to make directory for userdb database at {} with error: {}'.format(pdir, as_unicode(e))) from e
    try:
        return apsw.Connection(path)
    except apsw.CantOpenError as e:
        raise exc_class('Failed to open userdb database at {} with error: {}'.format(path, as_unicode(e))) from e
class UserManager:
    '''
    Store and query content-server users in an SQLite database.

    Each row of the ``users`` table holds a user's (unhashed) password, a
    read-only flag ('y'/'n'), a JSON-serialized restriction and JSON session
    data. The database connection is opened lazily via :attr:`conn`.
    '''

    # Class attribute: a single re-entrant lock shared by every UserManager
    # instance. Re-entrancy matters because locked methods call other locked
    # methods/properties (e.g. add_user -> validate_username -> get -> conn).
    lock = RLock()

    @property
    def conn(self):
        '''
        The database connection, opened on first access. For a brand-new
        database (PRAGMA user_version == 0) the ``users`` table is created and
        user_version is set to 1.
        '''
        with self.lock:
            if self._conn is None:
                conn = connect(self.path)
                try:
                    with conn:
                        c = conn.cursor()
                        try:
                            uv = next(c.execute('PRAGMA user_version'))[0]
                            if uv == 0:
                                # We have to store the unhashed password, since the digest
                                # auth scheme requires it. (Technically, one can store
                                # a MD5 hash of the username+realm+password, but it has to be
                                # without salt so it is trivially brute-forceable, anyway)
                                # timestamp stores the ISO 8601 creation timestamp in UTC.
                                # Defaults use single quotes: in SQL, double quotes denote
                                # identifiers and are only accepted as string literals by
                                # SQLite builds that keep the legacy DQS quirk enabled.
                                c.execute("""
                                CREATE TABLE users (
                                    id INTEGER PRIMARY KEY,
                                    name TEXT NOT NULL,
                                    pw TEXT NOT NULL,
                                    timestamp TEXT DEFAULT CURRENT_TIMESTAMP,
                                    session_data TEXT NOT NULL DEFAULT '{}',
                                    restriction TEXT NOT NULL DEFAULT '{}',
                                    readonly TEXT NOT NULL DEFAULT 'n',
                                    misc_data TEXT NOT NULL DEFAULT '{}',
                                    UNIQUE(name)
                                );

                                PRAGMA user_version=1;
                                """)
                        finally:
                            c.close()
                except Exception:
                    # Do not cache a connection whose schema setup failed; the
                    # next access retries from scratch.
                    conn.close()
                    raise
                self._conn = conn
        return self._conn

    def __init__(self, path=None):
        '''
        :param path: Path of the SQLite users database. Defaults to
            ``server-users.sqlite`` inside calibre's config directory.
        '''
        self.path = os.path.join(config_dir, 'server-users.sqlite') if path is None else path
        self._conn = None  # opened lazily by the conn property

    def get_session_data(self, username):
        ''' Return the decoded session data of *username*, or {} if there is no such user. '''
        with self.lock:
            for data, in self.conn.cursor().execute(
                    'SELECT session_data FROM users WHERE name=?', (username,)):
                return load_json(data)
        return {}

    def set_session_data(self, username, data):
        ''' Store *data*, JSON-serialized, as the session data of *username* (no-op for unknown users). '''
        with self.lock:
            conn = self.conn
            c = conn.cursor()
            data = as_json(data)
            # Presumably a Python 2 leftover: as_json returns str on Python 3.
            if isinstance(data, bytes):
                data = data.decode('utf-8')
            c.execute('UPDATE users SET session_data=? WHERE name=?', (data, username))

    def get(self, username):
        ''' Get password for user, or None if user does not exist '''
        with self.lock:
            for pw, in self.conn.cursor().execute(
                    'SELECT pw FROM users WHERE name=?', (username,)):
                return pw
        return None

    def has_user(self, username):
        ''' Return True if a user named *username* exists. '''
        return self.get(username) is not None

    def validate_username(self, username):
        ''' Return an error message if *username* is taken or contains discouraged characters, else None. '''
        if self.has_user(username):
            return _('The username %s already exists') % username
        return validate_username(username)

    def validate_password(self, pw):
        ''' Return an error message if *pw* is unacceptable, else None. '''
        return validate_password(pw)

    def add_user(self, username, pw, restriction=None, readonly=False):
        '''
        Create a new user.

        :param restriction: Restriction dict (see serialize_restriction); None means unrestricted.
        :param readonly: Whether the user gets read-only access.
        :raises ValueError: with the validation message if the username or password is rejected.
        '''
        with self.lock:
            msg = self.validate_username(username) or self.validate_password(pw)
            if msg is not None:
                raise ValueError(msg)
            restriction = restriction or {}
            self.conn.cursor().execute(
                'INSERT INTO users (name, pw, restriction, readonly) VALUES (?, ?, ?, ?)',
                (username, pw, serialize_restriction(restriction), ('y' if readonly else 'n')))

    def remove_user(self, username):
        ''' Delete *username*; return True if a row was actually removed. '''
        with self.lock:
            self.conn.cursor().execute('DELETE FROM users WHERE name=?', (username,))
            return self.conn.changes() > 0

    @property
    def all_user_names(self):
        ''' Set of the names of all users. '''
        with self.lock:
            return {x for x, in self.conn.cursor().execute(
                'SELECT name FROM users')}

    @property
    def user_data(self):
        ''' dict mapping each user name to its create_user_data() record. '''
        with self.lock:
            ans = {}
            for name, pw, restriction, readonly in self.conn.cursor().execute('SELECT name,pw,restriction,readonly FROM users'):
                ans[name] = create_user_data(pw, readonly.lower() == 'y', restriction)
        return ans

    @user_data.setter
    def user_data(self, users):
        '''
        Make the table match *users* (name -> dict with ``pw``, ``restriction``
        and ``readonly``). Users not in the mapping are deleted, existing ones
        are updated and new ones inserted, all inside one ``with conn`` block.
        '''
        with self.lock, self.conn:
            c = self.conn.cursor()
            remove = self.all_user_names - set(users)
            if remove:
                c.executemany('DELETE FROM users WHERE name=?', [(n,) for n in remove])
            for name, data in iteritems(users):
                res = serialize_restriction(data['restriction'])
                r = 'y' if data['readonly'] else 'n'
                c.execute('UPDATE users SET pw=?, restriction=?, readonly=? WHERE name=?',
                        (data['pw'], res, r, name))
                # The UPDATE touched a row, so the user already exists.
                if self.conn.changes() > 0:
                    continue
                c.execute('INSERT INTO users (name, pw, restriction, readonly) VALUES (?, ?, ?, ?)',
                          (name, data['pw'], res, r))
            self.refresh()

    def refresh(self):
        ''' No-op, kept for backwards compatibility. '''
        pass  # legacy compat

    def is_readonly(self, username):
        ''' Return True if *username* exists and is flagged read-only. '''
        with self.lock:
            for readonly, in self.conn.cursor().execute(
                    'SELECT readonly FROM users WHERE name=?', (username,)):
                return readonly == 'y'
            return False

    def set_readonly(self, username, value):
        ''' Set or clear the read-only flag of *username*. '''
        with self.lock:
            self.conn.cursor().execute(
                'UPDATE users SET readonly=? WHERE name=?', ('y' if value else 'n', username))

    def change_password(self, username, pw):
        '''
        Set a new password for *username*.

        :raises ValueError: with the validation message if *pw* is rejected.
        '''
        with self.lock:
            msg = self.validate_password(pw)
            if msg is not None:
                raise ValueError(msg)
            self.conn.cursor().execute(
                'UPDATE users SET pw=? WHERE name=?', (pw, username))

    def restrictions(self, username):
        ''' Return the parsed restriction dict of *username*, or None if there is no such user. '''
        with self.lock:
            for restriction, in self.conn.cursor().execute(
                    'SELECT restriction FROM users WHERE name=?', (username,)):
                return parse_restriction(restriction).copy()
        return None

    def allowed_library_names(self, username, all_library_names):
        '''
        Get allowed library names for specified user from set of all library names.

        A name is allowed when the allow-list is empty or contains it, and the
        block-list does not; comparisons are case-insensitive. Unknown users
        get an empty set.
        '''
        r = self.restrictions(username)
        if r is None:
            return set()
        inc = r['allowed_library_names']
        exc = r['blocked_library_names']

        def check(n):
            n = n.lower()
            return (not inc or n in inc) and n not in exc
        return {n for n in all_library_names if check(n)}

    def update_user_restrictions(self, username, restrictions):
        '''
        Replace the stored restriction of *username*.

        :raises TypeError: if *restrictions* is not a dict.
        '''
        if not isinstance(restrictions, dict):
            raise TypeError('restrictions must be a dict')
        with self.lock:
            self.conn.cursor().execute(
                'UPDATE users SET restriction=? WHERE name=?', (serialize_restriction(restrictions), username))

    def library_restriction(self, username, library_path):
        '''
        Return the restriction string *username* has for the library at
        *library_path* (matched by lower-cased directory name), or ''.
        '''
        r = self.restrictions(username)
        if r is None:
            return ''
        # normpath strips a trailing separator; without it basename('/x/Lib/')
        # is '' and the library's restriction would be silently ignored.
        library_name = os.path.basename(os.path.normpath(library_path)).lower()
        return r['library_restrictions'].get(library_name) or ''