1#!/usr/local/bin/python3.8
2
3# TODO:
4# FILES table: OBSOLETE column
5#     OBSOLETE says that this file was replaced by a newer version, for example checkrels may want to create a new file with only part of the output.
6#     Should this be a pointer to the file that replaced the obsolete one? How to signal a file that is obsolete, but not replaced by anything?
7#     If one file is replaced by several (say, due to a data corruption in the middle), we need a 1:n relationship. If several files are replaced by one (merge),
8#     we need n:1. What do? Do we really want an n:n relationship here? Disallow fragmenting files, or maybe simply not track it in the DB if we do?
9# FILES table: CHECKSUM column
10#     We need a fast check that the information stored in the DB still accurately reflects the file system contents. The test should also warn about files in upload/ which are not listed in DB
11
12
13# Make Python 2.7 use the print() syntax from Python 3
14from __future__ import print_function
15
16import sys
17import sqlite3
18try:
19    import mysql
20    import mysql.connector
21    HAVE_MYSQL=True
22except ImportError:
23    HAVE_MYSQL=False
24
25import threading
26import traceback
27import collections
28import abc
29from datetime import datetime
30import re
31import time
32
33if re.search("^/", "@CMAKE_INSTALL_PREFIX@/@LIBSUFFIX@"):
34    sys.path.append("@CMAKE_INSTALL_PREFIX@/@LIBSUFFIX@")
35
36from workunit import Workunit
37if sys.version_info.major == 3:
38    from queue import Queue
39else:
40    from Queue import Queue
41import patterns
42import cadologger
43import logging
44from collections.abc import MutableMapping
45from collections.abc import Mapping
46
DEBUG = 1

# Bookkeeping for BEGIN EXCLUSIVE, only maintained when DEBUG > 1:
# slot 0 is the connection holding the exclusive transaction, slot 1 the
# stack trace captured when it was begun. Mutated in place, never rebound.
exclusive_transaction = [None, None]

# Transaction-mode sentinels for CursorWrapperBase.begin(). They are
# compared by identity, so each must be a distinct object.
DEFERRED, IMMEDIATE, EXCLUSIVE = object(), object(), object()

logger = logging.getLogger("Database")
logger.setLevel(logging.NOTSET)


PRINTED_CANCELLED_WARNING = False
59
def join3(l, pre=None, post=None, sep=", "):
    """
    Join the strings of l with sep, wrapping each item as pre + item + post.

    If any of pre, post or sep is None, it is treated as the empty string.

    >>> join3 ( ('a',), pre="+", post="-", sep=", ")
    '+a-'
    >>> join3 ( ('a', 'b'), pre="+", post="-", sep=", ")
    '+a-, +b-'
    >>> join3 ( ('a', 'b'))
    'a, b'
    >>> join3 ( ('a', 'b', 'c'), pre="+", post="-", sep=", ")
    '+a-, +b-, +c-'
    """
    prefix = "" if pre is None else pre
    suffix = "" if post is None else post
    separator = "" if sep is None else sep
    return separator.join(prefix + item + suffix for item in l)
79
def dict_join3(d, sep=None, op=None, pre=None, post=None):
    """
    Render each (key, value) item of d as pre + key + op + value + post and
    join the results with sep. Keys and values must be strings.

    If any of sep, op, pre or post is None, it is treated as the empty string.

    >>> t = dict_join3 ( {"a": "1", "b": "2"}, sep=",", op="=", pre="-", post="+")
    >>> t == '-a=1+,-b=2+' or t == '-b=2+,-a=1+'
    True
    """
    pre, post, sep, op = ("" if part is None else part
                          for part in (pre, post, sep, op))
    return sep.join(pre + op.join(pair) + post for pair in d.items())
96
def conn_commit(conn):
    """ Commit the current transaction of conn.

    When DEBUG > 1, the exclusive-transaction bookkeeping is also cleared,
    after warning if the recorded exclusive lock belongs to a different
    connection. (logger.transaction is presumably a custom log level added
    by cadologger -- it is not a standard logging.Logger method.)
    """
    logger.transaction("Commit on connection %d", id(conn))
    if DEBUG > 1:
        owner = exclusive_transaction[0]
        if owner is not None and owner is not conn:
            logger.warning("Commit on connection %d, but exclusive lock was on %d", id(conn), id(owner))
        # Reset both slots in place; other code holds a reference to this list.
        exclusive_transaction[:] = [None, None]
    conn.commit()
105
def conn_close(conn):
    """ Close conn on a best-effort basis; never raises an Exception.

    A warning is logged if the connection is still inside a transaction.

    This is called from __del__ methods, which may run during interpreter
    shutdown when the logging machinery is already partly torn down. A
    failure was observed in that situation: the warning below raised
    "NameError: name 'open' is not defined" from inside logging's file
    handler. Logging and closing are therefore guarded separately, so that a
    logging failure can no longer prevent the connection from being closed.
    """
    try:
        logger.transaction("Closing connection %d", id(conn))
        if conn.in_transaction:
            logger.warning("Connection %d being closed while in transaction", id(conn))
    except Exception:
        # Diagnostics only; still attempt to close below.
        pass
    try:
        conn.close()
    except Exception:
        # Deliberately best-effort: callers (destructors) cannot handle errors.
        pass
132
133
# Workunit status constants: WuStatus.<NAME> is the integer value of a
# status, and WuStatus.get_name(value) maps a value back to its name.
STATUS_NAMES = ["AVAILABLE", "ASSIGNED", "NEED_RESUBMIT", "RECEIVED_OK",
                "RECEIVED_ERROR", "VERIFIED_OK", "VERIFIED_ERROR", "CANCELLED"]
STATUS_VALUES = range(len(STATUS_NAMES))
WuStatusBase = collections.namedtuple("WuStatusBase", STATUS_NAMES)


class WuStatusClass(WuStatusBase):
    """ Namedtuple holding the status values, with reverse lookup. """

    def check(self, status):
        """ Assert that status is one of the defined status values. """
        assert status in self

    def get_name(self, status):
        """ Return the name of a status value, after checking it. """
        self.check(status)
        return STATUS_NAMES[status]


WuStatus = WuStatusClass._make(STATUS_VALUES)
147
148
def check_tablename(name):
    """ Test whether name is a valid SQL table name.

    Underscores may appear anywhere. Once they are removed, what remains
    must be non-empty, start with a letter and continue with letters or
    digits only. Names passing this check therefore contain no quotes,
    spaces or other punctuation.

    Raise an exception if it isn't.
    """
    stripped = name.replace("_", "")
    # The emptiness test avoids an IndexError for "" or "_". all() over an
    # empty tail is True, so one-letter names such as "a" are accepted
    # ("".isalnum() would have been False).
    if not stripped or not stripped[0].isalpha() \
            or not all(c.isalnum() for c in stripped[1:]):
        raise Exception("%s is not valid for an SQL table name" % name)
157
class StatusUpdateError(Exception):
    """ Raised when a workunit status is updated non-progressively.

    Status updates must only move forward (AVAILABLE -> ASSIGNED -> ...);
    any other kind of update raises this exception.
    """
162
# I wish I knew how to make that inherit from a template argument (which
# would be sqlite3.Cursor or mysql.Cursor). I'm having difficulties to
# grok that syntax though, so let's stay simple and stupid. We'll have a
# *member* which is the cursor object, and so be it.
class CursorWrapperBase(object, metaclass=abc.ABCMeta):
    """ This class represents a DB cursor and provides convenience functions
        around SQL queries. In particular it is meant to provide
        (1) an interface to SQL functionality via method calls with
            parameters, and
        (2) hiding some particularities of the SQL variant of the underlying
            DBMS as far as possible.

        Derived classes must implement the `cursor` property (the wrapped
        DB-API cursor) and the `connection` property. They may override
        `_string_translations` (regex rewrites applied to every command) and
        `parameter_auto_increment` (the query parameter placeholder).
    """

    # This is used in where queries; it converts from named arguments such as
    # "eq" to a binary operator such as "="
    name_to_operator = {"lt": "<", "le": "<=", "eq": "=", "ge": ">=",
                        "gt": ">", "ne": "!=", "like": "like"}

    @property
    @abc.abstractmethod
    def cursor(self):
        """ The wrapped DB-API cursor object """

    @property
    @abc.abstractmethod
    def connection(self):
        """ The connection object the wrapped cursor belongs to """

    # override in the derived cursor class if needed
    @property
    def _string_translations(self):
        """ (regex, replacement) pairs applied, in order, to every command """
        return []

    # override in the derived cursor class if needed
    def translations(self, x):
        """ Apply _string_translations to the string x.

            Tuples and lists are processed element-wise, keeping their type.
        """
        if type(x) == tuple:
            return tuple([self.translations(u) for u in x])
        elif type(x) == list:
            return [self.translations(u) for u in x]
        else:
            v = x
            for pattern, replacement in self._string_translations:
                v = re.sub(pattern, replacement, v)
            return v

    # override in the derived cursor class if needed
    @property
    def parameter_auto_increment(self):
        """ Placeholder for one query parameter in SQL text """
        return "?"

    def __init__(self):
        pass

    def in_transaction(self):
        return self.connection.in_transaction

    @staticmethod
    def _without_None(d):
        """ Return a copy of the dictionary d, but without entries whose values
            are None """
        return {k: v for k, v in d.items() if v is not None}

    @staticmethod
    def as_string(d):
        """ Render the alias dict d as ", k1 AS v1, k2 AS v2", or "" if d is
            None """
        if d is None:
            return ""
        else:
            return ", " + dict_join3(d, sep=", ", op=" AS ")

    def _where_str(self, name, **args):
        """ Build " name col1 op ? AND col2 op ? ..." from keyword arguments
            such as eq={"col1": 1}, lt={"col2": 5}.

            Each keyword must be a key of name_to_operator; keywords whose
            value is None are skipped. Returns (clause, values), where clause
            is "" if there is no condition at all.
        """
        where = ""
        values = []
        qm = self.parameter_auto_increment
        for opname in args:
            if args[opname] is None:
                continue
            if where == "":
                where = " " + name + " "
            else:
                where = where + " AND "
            where = where + join3(args[opname].keys(),
                        post=" " + self.name_to_operator[opname] + " " + qm,
                        sep=" AND ")
            values = values + list(args[opname].values())
        return (where, values)

    def _exec(self, command, values=None):
        """ Wrapper around self.cursor.execute() that logs the command for
            debugging and retries in case of "database locked" exception.

            The command is first passed through translations(). On an sqlite3
            "database is locked" error, execution is attempted up to 10 times,
            sleeping one second after each failure. Other sqlite3 errors are
            re-raised immediately.
        """

        # FIXME: should be the caller's class name, as _exec could be
        # called from outside this class
        classname = self.__class__.__name__
        parent = sys._getframe(1).f_code.co_name
        command = self.translations(command)
        command_str = command.replace("?", "%r")
        if values is not None:
            command_str = command_str % tuple(values)
        logger.transaction("%s.%s(): connection = %s, command = %s",
                           classname, parent, id(self.connection), command_str)
        i = 0
        while True:
            try:
                if values is None or len(values) == 0:
                    self.cursor.execute(command)
                else:
                    self.cursor.execute(command, values)
                break
            except (sqlite3.OperationalError, sqlite3.DatabaseError) as e:
                msg = str(e)
                if msg == "database disk image is malformed" or \
                        msg == "disk I/O error":
                    logger.critical("sqlite3 reports error accessing the database.")
                    logger.critical("Database file may have gotten corrupted, "
                            "or maybe filesystem does not properly support "
                            "file locking.")
                    raise
                if msg != "database is locked":
                    raise
                i += 1
                time.sleep(1) # wait for 1 second if database is locked
                if i == 10:
                    logger.critical("You might try 'fuser xxx.db' to see which process is locking the database")
                    raise
        logger.transaction("%s.%s(): connection = %s, command finished",
                           classname, parent, id(self.connection))

    def begin(self, mode=None):
        """ Start a transaction. mode is None, DEFERRED, IMMEDIATE or
            EXCLUSIVE; anything else raises TypeError.

            With DEBUG > 1, exclusive transactions are recorded in the module
            global exclusive_transaction, to diagnose overlapping exclusive
            transactions. """
        if mode is None:
            self._exec("BEGIN")
        elif mode is DEFERRED:
            self._exec("BEGIN DEFERRED")
        elif mode is IMMEDIATE:
            self._exec("BEGIN IMMEDIATE")
        elif mode is EXCLUSIVE:
            if DEBUG > 1:
                tb = traceback.extract_stack()
                if not exclusive_transaction == [None, None]:
                    old_tb_str = "".join(traceback.format_list(exclusive_transaction[1]))
                    new_tb_str = "".join(traceback.format_list(tb))
                    logger.warning("Called Cursor.begin(EXCLUSIVE) when there was already an exclusive transaction %d\n%s",
                                        id(exclusive_transaction[0]), old_tb_str)
                    logger.warning("New transaction: %d\n%s", id(self.connection), new_tb_str)

            self._exec("BEGIN EXCLUSIVE")

            if DEBUG > 1:
                assert exclusive_transaction == [None, None]
                exclusive_transaction[0] = self.connection
                exclusive_transaction[1] = tb
        else:
            raise TypeError("Invalid mode parameter: %r" % mode)

    def pragma(self, prag):
        self._exec("PRAGMA %s;" % prag)

    def create_table(self, table, layout):
        """ Creates a table with fields as described in the layout parameter.

            layout is a sequence of string tuples; each tuple is joined with
            spaces to form one column definition. """
        command = "CREATE TABLE IF NOT EXISTS %s( %s );" % \
                  (table, ", ".join(" ".join(k) for k in layout))
        self._exec(command)

    def create_index(self, name, table, columns):
        """ Creates an index with fields as described in the columns list.

            Errors are logged as warnings and otherwise ignored. (The MySQL
            back-end rewrites this to a plain CREATE INDEX, which presumably
            fails when the index already exists.) """
        # we get so many of these...
        try:
            command = self.translations("CREATE INDEX IF NOT EXISTS") + " %s ON %s( %s );" % (name, table, ", ".join(columns))
            self._exec(command)
        except Exception as e:
            logger.warning(e)

    def insert(self, table, d):
        """ Insert a new entry, where d is a dictionary containing the
            field:value pairs. Returns the row id of the newly created entry """
        # INSERT INTO table (field_1, field_2, ..., field_n)
        # 	VALUES (value_1, value_2, ..., value_n)

        # Fields is a copy of d but with entries removed that have value None.
        # This is done primarily to avoid having "id" listed explicitly in the
        # INSERT statement, because the DB fills in a new value automatically
        # if "id" is the primary key. But I guess not listing field:NULL items
        # explicitly in an INSERT is a good thing in general
        fields = self._without_None(d)
        fields_str = ", ".join(fields.keys())

        qm = self.parameter_auto_increment
        sqlformat = ", ".join((qm,) * len(fields)) # sqlformat = "?, ?, ?, " ... "?"
        command = "INSERT INTO %s( %s ) VALUES ( %s );" \
                  % (table, fields_str, sqlformat)
        values = list(fields.values())
        self._exec(command, values)
        return self.lastrowid

    def update(self, table, d, **conditions):
        """ Update fields of existing entries. conditions specifies the where
            clause selecting the rows to update (see _where_str()); the
            entries of the dictionary d are the fields and their new values """
        # UPDATE table SET column_1=value1, column2=value_2, ...,
        # column_n=value_n WHERE column_n+1=value_n+1, ...,
        qm = self.parameter_auto_increment
        setstr = join3(d.keys(), post=" = " + qm, sep=", ")
        (wherestr, wherevalues) = self._where_str("WHERE", **conditions)
        command = "UPDATE %s SET %s %s" % (table, setstr, wherestr)
        values = list(d.values()) + wherevalues
        self._exec(command, values)

    def where_query(self, joinsource, col_alias=None, limit=None, order=None,
                    **conditions):
        """ Build, without executing, the SELECT statement used by where().

            Returns (command, values). Table and column names cannot be
            passed as query parameters, so joinsource, col_alias and order
            are interpolated into the SQL text directly. Raises Exception if
            the sort direction order[1] is neither "ASC" nor "DESC". """
        (WHERE, values) = self._where_str("WHERE", **conditions)
        if order is None:
            ORDER = ""
        else:
            if not order[1] in ("ASC", "DESC"):
                raise Exception("Invalid sort direction %r, expected ASC or DESC"
                                % (order[1],))
            ORDER = " ORDER BY %s %s" % (order[0], order[1])
        if limit is None:
            LIMIT = ""
        else:
            LIMIT = " LIMIT %s" % int(limit)
        AS = self.as_string(col_alias)
        command = "SELECT * %s FROM %s %s %s %s" \
                  % (AS, joinsource, WHERE, ORDER, LIMIT)
        return (command, values)

    def where(self, joinsource, col_alias=None, limit=None, order=None,
              values=None, **conditions):
        """ Select up to "limit" table rows (limit=None: no limit) for which
            the conditions (see _where_str()) hold, optionally sorted by
            order=(column, "ASC" or "DESC"). The result stays in the cursor.

            values are extra query parameters placed before those generated
            from conditions, for placeholders occurring before the WHERE
            clause (e.g. in joinsource). """
        (command, newvalues) = self.where_query(joinsource, col_alias, limit,
                                                order, **conditions)
        extra = [] if values is None else list(values)
        self._exec(command + ";", extra + newvalues)

    def count(self, joinsource, **conditions):
        """ Count rows where the key:value pairs of the dictionary "conditions" are
            set to the same value in the database table """

        # Table/Column names cannot be substituted, so include in query directly.
        (WHERE, values) = self._where_str("WHERE", **conditions)

        command = "SELECT COUNT(*) FROM %s %s;" % (joinsource, WHERE)
        self._exec(command, values)
        r = self.cursor.fetchone()
        return int(r[0])

    def delete(self, table, **conditions):
        """ Delete the rows specified by conditions """
        (WHERE, values) = self._where_str("WHERE", **conditions)
        command = "DELETE FROM %s %s;" % (table, WHERE)
        self._exec(command, values)

    def where_as_dict(self, joinsource, col_alias=None, limit=None,
                      order=None, values=None, **conditions):
        """ Like where(), but fetch all result rows and return them as a list
            of {column name: value} dictionaries """
        self.where(joinsource, col_alias=col_alias, limit=limit,
                   order=order, values=values, **conditions)
        # cursor.description is a list of lists, where the first element of
        # each inner list is the column name
        result = []
        desc = [k[0] for k in self.cursor.description]
        row = self.cursor.fetchone()
        while row is not None:
            result.append(dict(zip(desc, row)))
            row = self.cursor.fetchone()
        return result

    def execute(self, *args, **kwargs):
        return self._exec(*args, **kwargs)

    def fetchone(self, *args, **kwargs):
        return self.cursor.fetchone(*args, **kwargs)

    def close(self):
        self.cursor.close()

    @property
    def lastrowid(self):
        """ Row id of the last inserted row, as reported by the wrapped
            cursor """
        return self.cursor.lastrowid
434
class DB_base(object):
    """ Common URI parsing for the database back-ends.

    A URI has the form
        [db:]backend://[[user[:password]@](host|[ipv6address])[:port]/]dbname
    The parsed parts are stored in self.backend, self.hostname,
    self.host_ipv6, self.db_connect_args (keys user, password, host, port;
    values are None when absent) and self.db_name.
    """
    @property
    def general_pattern(self):
        # Groups: 1 backend, 2 user, 3 password, 4 host name / IPv4 address,
        # 5 bracketed IPv6 address (hex digits, colons, and dots for
        # IPv4-mapped forms), 6 port, 7 database name.
        return r"(?:db:)?(\w+)://(?:(?:(\w+)(?::(.*))?@)?(?:([\w\.]+)|\[([0-9A-Fa-f:.]+)\])(?::(\d+))?/)?(.*)$"

    def __init__(self, uri, backend_pattern=None):
        """ Parse uri; raise ValueError if it does not match general_pattern
            or if its backend does not match backend_pattern. """
        self.uri = uri
        match = re.match(self.general_pattern, uri)
        if not match:
            raise ValueError("db URI %s does not match regexp %s" % (uri, self.general_pattern))
        self.hostname = match.group(4)
        self.host_ipv6 = False
        if not self.hostname:
            self.hostname = match.group(5)
            # Only a bracketed address is IPv6; having no host at all is not.
            self.host_ipv6 = self.hostname is not None
        self.backend = match.group(1)
        if backend_pattern is not None and not re.match(backend_pattern, self.backend):
            raise ValueError("back-end type %s not supported, expected %s" % (self.backend, backend_pattern))
        self.db_connect_args = dict(
                user=match.group(2),
                password=match.group(3),
                host=self.hostname,
                port=match.group(6)
        )
        self.db_name = match.group(7)
        self.talked = False

    @property
    def uri_without_credentials(self):
        """ The URI with user name and password replaced by placeholders,
            e.g. for logging. Parts absent from the original URI are omitted.
        """
        # db_connect_args always contains all four keys, so test the values,
        # not key membership.
        d = self.db_connect_args
        text = "db:%s://" % self.backend
        if d.get("host") is not None:
            if d.get("user") is not None:
                text += "USERNAME"
                if d.get("password") is not None:
                    text += ":PASSWORD"
                text += "@"
            if self.host_ipv6:
                text += "[%s]" % d["host"]
            else:
                text += d["host"]
            if d.get("port") is not None:
                text += ":%s" % d["port"]
            text += "/"
        text += self.db_name
        return text

    def advertise_connection(self):
        """ Log the first successful connection, once per instance """
        if not self.talked:
            logger.info("Opened connection to database %s", self.db_name)
            self.talked = True
484
class DB_SQLite(DB_base):
    """ SQLite back-end. The database name taken from the URI is used as the
        path handed to sqlite3. """

    class CursorWrapper(CursorWrapperBase):
        """ Wraps a plain sqlite3 cursor. """

        def __init__(self, cursor, *args, **kwargs):
            self.__cursor = cursor
            super().__init__(*args, **kwargs)

        @property
        def cursor(self):
            return self.__cursor

        @property
        def connection(self):
            # sqlite3 cursors expose the connection they were created from
            return self.cursor.connection

    class ConnectionWrapper(sqlite3.Connection):
        """ sqlite3 connection that hands out CursorWrapper objects and is
            opened with isolation_level=None, so the sqlite3 module does not
            open transactions implicitly; they are started explicitly via
            CursorWrapperBase.begin(). """

        def __init__(self, *args, **kwargs):
            super().__init__(*args, isolation_level=None, **kwargs)

        def cursor(self):
            return DB_SQLite.CursorWrapper(super().cursor())

    # `create` is accepted for interface compatibility with the other
    # back-ends but is not used here.
    # FIXME I think that in any case the sqlite3 module ends up creating
    # the db, no ?
    def __init__(self, uri, create=False):
        super().__init__(uri, backend_pattern="sqlite3?")
        self.path = self.db_name

    def connect(self):
        """ Open and return a new ConnectionWrapper on self.path """
        connection = self.ConnectionWrapper(self.path)
        self.advertise_connection()
        return connection
510
if HAVE_MYSQL:
    class DB_MySQL(DB_base):
        """ MySQL back-end; only defined when mysql.connector is importable.

            With create=True, the constructor connects once (creating the
            database if it does not exist yet) and closes that connection. """

        class CursorWrapper(CursorWrapperBase):
            """ Wraps a mysql.connector cursor and rewrites the SQLite-style
                SQL built by CursorWrapperBase into MySQL syntax. """
            @property
            def parameter_auto_increment(self):
                # mysql.connector uses %s placeholders
                return "%s"

            @property
            def _string_translations(self):
                return [
                        # Only "... PRIMARY KEY ASC" column definitions are
                        # meant to become AUTO_INCREMENT columns. Matching a
                        # bare "ASC" would also clobber the sort direction of
                        # "ORDER BY col ASC" clauses built by where_query().
                        ('\\bPRIMARY\\s+KEY\\s+ASC\\b', "PRIMARY KEY AUTO_INCREMENT"),
                        ('\\bCREATE INDEX IF NOT EXISTS\\b', "CREATE INDEX"),
                        ('\\bBEGIN EXCLUSIVE\\b', "START TRANSACTION"),
                        # presumably because PURGE is a MySQL reserved word
                        ('\\bpurge\\b', "purgetable"),
                ]

            @property
            def cursor(self):
                return self.__cursor

            @property
            def connection(self):
                return self._connection

            def __init__(self, cursor, connection=None, *args, **kwargs):
                self._connection = connection
                self.__cursor = cursor
                super().__init__(*args, **kwargs)

        class ConnectionWrapper(object):
            """ Wraps a mysql.connector connection, reconnecting on demand. """

            # Number of attempts cursor() makes before giving up
            _CURSOR_ATTEMPTS = 10

            def _reconnect_anonymous(self):
                self._conn = mysql.connector.connect(**self._db_factory.db_connect_args)

            def _reconnect(self):
                self._conn = mysql.connector.connect(database=self._db_factory.db_name, **self._db_factory.db_connect_args)

            def cursor(self):
                """ Return a DB_MySQL.CursorWrapper on this connection,
                    reconnecting and retrying if the server reports an
                    OperationalError. The last failure is re-raised. """
                # This must be done on the connection object, since
                # reconnecting replaces self._conn.
                for attempt in range(self._CURSOR_ATTEMPTS):
                    try:
                        c = self._conn.cursor()
                        break
                    except mysql.connector.errors.OperationalError:
                        if attempt == self._CURSOR_ATTEMPTS - 1:
                            raise
                        logger.warning("Got exception connecting to the database, retrying (#%d)", attempt)
                        try:
                            self._reconnect()
                        except mysql.connector.errors.Error as e:
                            # Keep the old connection object; the next
                            # attempt will fail and retry again.
                            logger.warning("Reconnecting to the database failed: %s", e)
                self._conn.commit()
                return DB_MySQL.CursorWrapper(c, connection=self)

            def __init__(self, db_factory, create=False):
                self._db_factory = db_factory
                db_name = self._db_factory.db_name
                if create:
                    try:
                        self._reconnect()
                    except mysql.connector.errors.ProgrammingError:
                        # need to create the database first. Do it by
                        # hand, with a connection which starts without a
                        # database name.
                        # NOTE(review): db_name comes from the URI and is
                        # interpolated unquoted into the SQL below.
                        logger.info("Creating database %s" % db_name)
                        self._reconnect_anonymous()
                        cursor = self._conn.cursor()
                        cursor.execute("CREATE DATABASE %s;" % db_name)
                        cursor.execute("USE %s;" % db_name)
                        cursor.execute("SET autocommit = 1")
                        self._conn.commit()
                else:
                    self._reconnect()

            def rollback(self):
                self._conn.rollback()

            def close(self):
                self._conn.close()

            def commit(self):
                self._conn.commit()

            @property
            def in_transaction(self):
                return self._conn.in_transaction

        def connect(self, *args, **kwargs):
            return self.ConnectionWrapper(self, *args, **kwargs)

        def __init__(self, uri, create=False):
            super().__init__(uri, backend_pattern="mysql")
            self.path = None
            if create:
                conn = self.connect(create=True)
                conn.close()
593
class DBFactory(object):
    """ Select and instantiate the database back-end matching a URI.

    Every direct subclass of DB_base is tried in turn with the URI (and any
    extra arguments); the first whose constructor does not raise ValueError
    becomes self.base. If none accepts the URI, a ValueError listing the
    message from each back-end is raised.
    The URI syntax is [db:]engine://[[user[:password]@]host[:port]/]dbname
    """
    def __init__(self, uri, *args, **kwargs):
        self.uri = uri
        self.base = None
        backends = DB_base.__subclasses__()
        errors = {}
        for backend in backends:
            try:
                candidate = backend(uri, *args, **kwargs)
            except ValueError as err:
                errors[str(backend)] = err
            else:
                self.base = candidate
                break
        if self.base is None:
            lines = ["Cannot use database URI %s" % uri,
                     "Messages received from %d backends:" % len(backends)]
            lines.extend("Error from %s: %s" % (name, problem)
                         for name, problem in errors.items())
            raise ValueError("\n".join(lines))

    def connect(self):
        """ Open a new connection through the selected back-end """
        return self.base.connect()

    @property
    def uri_without_credentials(self):
        return self.base.uri_without_credentials

    @property
    def path(self):
        # TODO: remove
        return self.base.path
625
626
class DbTable(object):
    """ A class template defining access methods to a database table.

    Subclasses provide:
      tablename  -- name of the SQL table
      fields     -- tuple of (column name, SQL type, constraints) tuples
      primarykey -- name of the primary key column
      references -- DbTable instance this table references, or None
      index      -- dict mapping an index key to a tuple of column names
    """

    @staticmethod
    def _subdict(d, l):
        """ Return a dictionary of those key:value pairs of d whose key is
            in l, or None if d is None """
        if d is None:
            return None
        wanted = set(l)
        return {k: v for k, v in d.items() if k in wanted}

    def _get_colnames(self):
        return [k[0] for k in self.fields]

    def getname(self):
        return self.tablename

    def getpk(self):
        return self.primarykey

    def dictextract(self, d):
        """ Return a dictionary with all those key:value pairs of d
            for which key is in self._get_colnames() """
        return self._subdict(d, self._get_colnames())

    def create(self, cursor):
        """ Create the table and its indexes if they do not exist yet.

            If the table references another table, a foreign-key column
            named after the referenced table's primary key is appended and
            indexed. Index-creation errors are logged as warnings. """
        fields = list(self.fields)
        if self.references:
            # If this table references another table, we use the primary
            # key of the referenced table as the foreign key name
            r = self.references # referenced table
            fk = (r.getpk(), "INTEGER", "REFERENCES %s ( %s ) " \
                  % (r.getname(), r.getpk()))
            fields.append(fk)
        cursor.create_table(self.tablename, fields)
        if self.references:
            # We always create an index on the foreign key
            cursor.create_index(self.tablename + "_pkindex", self.tablename,
                                (fk[0], ))
        for indexname in self.index:
            try:
                cursor.create_index(self.tablename + "_" + indexname + "_index",
                                    self.tablename, self.index[indexname])
            except Exception as e:
                logger.warning(e)

    def insert(self, cursor, values, foreign=None):
        """ Insert a new row into this table.

            The column:value pairs are taken from the dictionary values;
            keys that are not columns of this table are ignored. If foreign
            is given, it is stored in the foreign-key column (named after the
            referenced table's primary key). The row id returned by
            cursor.insert() is stored in values[primarykey]. """
        d = self.dictextract(values)
        assert self.primarykey not in d or d[self.primarykey] is None
        # If a foreign key is specified in foreign, add it to the column
        # that is marked as being a foreign key
        if foreign:
            r = self.references.primarykey
            assert not r in d or d[r] is None
            d[r] = foreign
        values[self.primarykey] = cursor.insert(self.tablename, d)

    def insert_list(self, cursor, values, foreign=None):
        """ insert() every dictionary of the iterable values """
        for v in values:
            self.insert(cursor, v, foreign)

    def update(self, cursor, d, **conditions):
        """ Update existing rows in this table. The column:value pairs to
            be written are the key:value pairs of the dictionary d """
        cursor.update(self.tablename, d, **conditions)

    def delete(self, cursor, **conditions):
        """ Delete the rows of this table matching conditions """
        cursor.delete(self.tablename, **conditions)

    def where(self, cursor, limit=None, order=None, **conditions):
        """ Return the matching rows as a list of dictionaries.

            order, if given, is (column, "ASC" or "DESC"). Raises ValueError
            if the column is not a column of this table. """
        # order[0] is interpolated verbatim into the ORDER BY clause, so it
        # needs a real check rather than an assert (stripped under -O).
        if order is not None and order[0] not in self._get_colnames():
            raise ValueError("Cannot order by %r: not a column of table %s"
                             % (order[0], self.tablename))
        return cursor.where_as_dict(self.tablename, limit=limit,
                                    order=order, **conditions)
707
708
class WuTable(DbTable):
    """ Schema of the "workunits" table, one row per workunit. """
    tablename = "workunits"
    # (column name, SQL type, constraints) triples; create_table() joins each
    # triple with spaces to form one column definition.
    fields = (
        ("wurowid", "INTEGER PRIMARY KEY ASC", "UNIQUE NOT NULL"),
        ("wuid", "VARCHAR(512)", "UNIQUE NOT NULL"),
        ("submitter", "VARCHAR(512)", ""),
        ("status", "INTEGER", "NOT NULL"), # presumably a WuStatus value -- TODO confirm
        ("wu", "TEXT", "NOT NULL"),
        ("timecreated", "TEXT", ""),
        ("timeassigned", "TEXT", ""),
        ("assignedclient", "TEXT", ""),
        ("timeresult", "TEXT", ""),
        ("resultclient", "TEXT", ""),
        ("errorcode", "INTEGER", ""),
        ("failedcommand", "INTEGER", ""),
        ("timeverified", "TEXT", ""),
        # References this same table; presumably the row of the workunit
        # this one is a retry of -- TODO confirm
        ("retryof", "INTEGER", "REFERENCES %s" % tablename),
        ("priority", "INTEGER", "")
    )
    # Primary key column: "wurowid"
    primarykey = fields[0][0]
    # No foreign key column to another table
    references = None
    # Index key -> tuple of indexed columns; DbTable.create() names each
    # index "<tablename>_<key>_index".
    index = {"wuid": (fields[1][0],),
             "submitter" : (fields[2][0],),
             "priority" : (fields[14][0],),
             "status" : (fields[3][0],)
    }
735
class FilesTable(DbTable):
    """ Schema of the "files" table. """
    tablename = "files"
    # (column name, SQL type, constraints) triples; create_table() joins each
    # triple with spaces to form one column definition.
    fields = (
        ("filesrowid", "INTEGER PRIMARY KEY ASC", "UNIQUE NOT NULL"),
        ("filename", "TEXT", ""),
        ("path", "VARCHAR(512)", "UNIQUE NOT NULL"),
        ("type", "TEXT", ""),
        ("command", "INTEGER", "")
    )
    # Primary key column: "filesrowid"
    primarykey = fields[0][0]
    # Each file row points to a workunit: DbTable.create() appends a
    # foreign-key column named after WuTable's primary key ("wurowid").
    references = WuTable()
    # No extra indexes besides the foreign-key index created by DbTable.create()
    index = {}
748
749
# The sqrt_factors table contains the input number to be factored. As
# such, we must make sure that it's permitted to go at least as far as we
# intend to go. 200 digits is definitely too small.
class DictDbTable(DbTable):
    """ Generic key/value table layout; the table name is given per instance
        through the keyword-only argument `name`. """
    fields = (
        ("rowid", "INTEGER PRIMARY KEY ASC", "UNIQUE NOT NULL"),
        ("kkey", "VARCHAR(300)", "UNIQUE NOT NULL"),
        ("type", "INTEGER", "NOT NULL"),
        ("value", "TEXT", "")
        )
    primarykey = "rowid"
    references = None

    def __init__(self, *args, name=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.tablename = name
        # DbTable.create() prepends the table name and appends "_index"
        self.index = {"dictdb_kkey": ("kkey",)} # useful ?
767
768
class DictDbAccess(MutableMapping):
    """ A DB-backed flat dictionary.

    Flat means that the value of each dictionary entry must be a type that
    the underlying DB understands, like integers, strings, etc., but not
    collections or other complex types.

    A copy of all the data in the table is kept in memory; read accesses
    are always served from the in-memory dict. Write accesses write through
    to the DB.

    >>> conn = DBFactory('db:sqlite3://:memory:').connect()
    >>> d = DictDbAccess(conn, 'test')
    >>> d == {}
    True
    >>> d['a'] = '1'
    >>> d == {'a': '1'}
    True
    >>> d['a'] = 2
    >>> d == {'a': 2}
    True
    >>> d['b'] = '3'
    >>> d == {'a': 2, 'b': '3'}
    True
    >>> del(d)
    >>> d = DictDbAccess(conn, 'test')
    >>> d == {'a': 2, 'b': '3'}
    True
    >>> del(d['b'])
    >>> d == {'a': 2}
    True
    >>> d.setdefault('a', '3')
    2
    >>> d == {'a': 2}
    True
    >>> d.setdefault('b', 3.0)
    3.0
    >>> d == {'a': 2, 'b': 3.0}
    True
    >>> d.setdefault(None, {'a': '3', 'c': '4'})
    >>> d == {'a': 2, 'b': 3.0, 'c': '4'}
    True
    >>> d.update({'a': '3', 'd': True})
    >>> d == {'a': '3', 'b': 3.0, 'c': '4', 'd': True}
    True
    >>> del(d)
    >>> d = DictDbAccess(conn, 'test')
    >>> d == {'a': '3', 'b': 3.0, 'c': '4', 'd': True}
    True
    >>> d.clear(['a', 'd'])
    >>> d == {'b': 3.0, 'c': '4'}
    True
    >>> del(d)
    >>> d = DictDbAccess(conn, 'test')
    >>> d == {'b': 3.0, 'c': '4'}
    True
    >>> d.clear()
    >>> d == {}
    True
    >>> del(d)
    >>> d = DictDbAccess(conn, 'test')
    >>> d == {}
    True
    """

    # Supported value types. The DB "type" column stores the index into this
    # tuple, so the order must never change.
    types = (str, int, float, bool)

    def __init__(self, db, name):
        ''' Attaches to a DB table and reads values stored therein.

        db is either a DBFactory, in which case a new connection is opened
        (and closed again when this object is deleted), or an already-open
        DB connection, which is used as-is and never closed by this object.
        The latter is allowed primarily for making the doctest work, so we
        can reuse the same memory-backed DB connection, but it may be useful
        in other contexts. A plain string is rejected with ValueError.

        name is the name of the backing table; it is created if it does not
        exist yet.
        '''

        if isinstance(db, DBFactory):
            self._db = db
            self._conn = db.connect()
            self._ownconn = True
        elif isinstance(db, str):
            raise ValueError("unexpected: %s" % db)
        else:
            self._db = None
            self._conn = db
            self._ownconn = False
        self._table = DictDbTable(name = name)
        cursor = self.get_cursor()
        try:
            # Create an empty table if none exists
            self._table.create(cursor)
            # Get the entries currently stored in the DB
            self._data = self._getall()
        finally:
            cursor.close()

    def get_cursor(self):
        return self._conn.cursor()

    # Implement the abstract methods defined by collections.abc.MutableMapping
    # All but __delitem__ and __setitem__ are simply passed through to the
    # self._data dictionary

    def __getitem__(self, key):
        return self._data.__getitem__(key)

    def __iter__(self):
        return self._data.__iter__()

    def __len__(self):
        return self._data.__len__()

    def __str__(self):
        return self._data.__str__()

    def __del__(self):
        """ Close the DB connection if this object opened it """
        # _ownconn does not exist if __init__ raised before setting it
        if getattr(self, "_ownconn", False):
            # When we shut down Python hard, e.g., in an exception, the
            # conn_close() function object may have been destroyed already
            # and trying to call it would raise another exception.
            if callable(conn_close):
                conn_close(self._conn)
            else:
                self._conn.close()

    def __convert_value(self, row):
        """ Turn a DB row back into the Python value it stores.

        row["type"] is an index into self.types, row["value"] is the str()
        of the stored value. Raises ValueError for a malformed bool. """
        valuestr = row["value"]
        valuetype = row["type"]
        # Look up constructor for this type
        typecon = self.types[int(valuetype)]
        # Bool is handled separately as bool("False") == True
        if typecon == bool:
            if valuestr == "True":
                return True
            elif valuestr == "False":
                return False
            else:
                raise ValueError("Value %s invalid for Bool type" % valuestr)
        return typecon(valuestr)

    def __get_type_idx(self, value):
        """ Return the index in self.types of value's exact type.

        Raises TypeError for unsupported types (subclasses do not match). """
        valuetype = type(value)
        for (idx, t) in enumerate(self.types):
            if valuetype == t:
                return idx
        raise TypeError("Type %s not supported" % str(valuetype))

    def _getall(self):
        """ Reads the whole table and returns it as a dict """
        cursor = self.get_cursor()
        try:
            rows = self._table.where(cursor)
        finally:
            cursor.close()
        return {r["kkey"]: self.__convert_value(r) for r in rows}

    def __setitem_nocommit(self, cursor, key, value):
        """ Set dictionary key to value and update/insert into table,
        but don't commit. Cursor must be given
        """
        update = {"value": str(value), "type": self.__get_type_idx(value)}
        if key in self._data:
            # Update the table row where column "key" equals key
            self._table.update(cursor, update, eq={"kkey": key})
        else:
            # Insert a new row
            update["kkey"] = key
            self._table.insert(cursor, update)
        # Update the in-memory dict
        self._data[key] = value

    def __setitem__(self, key, value):
        """ Access by indexing, e.g., d["foo"]. Always commits """
        cursor = self.get_cursor()
        try:
            if not cursor.in_transaction:
                cursor.begin(EXCLUSIVE)
            self.__setitem_nocommit(cursor, key, value)
            conn_commit(self._conn)
        finally:
            cursor.close()

    def __delitem__(self, key, commit=True):
        """ Delete a key from the dictionary and the table.

        Raises KeyError, without touching the DB, if key is not present. """
        # Check first so that a missing key does not start (and, with
        # commit=False, leave open) an exclusive transaction.
        if key not in self._data:
            raise KeyError(key)
        cursor = self.get_cursor()
        try:
            if not cursor.in_transaction:
                cursor.begin(EXCLUSIVE)
            self._table.delete(cursor, eq={"kkey": key})
            if commit:
                conn_commit(self._conn)
        finally:
            cursor.close()
        del self._data[key]

    def setdefault(self, key, default = None, commit=True):
        ''' Setdefault function that allows a mapping as input

        With key=None and a Mapping as default, values from default are
        merged into self, *not* overwriting existing values in self, and
        None is returned. Otherwise behaves like dict.setdefault(). '''
        if key is None and isinstance(default, Mapping):
            missing = {k: default[k] for k in default if k not in self}
            if missing:
                self.update(missing, commit=commit)
            return None
        elif key not in self:
            self.update({key: default}, commit=commit)
        return self._data[key]

    def update(self, other, commit=True):
        """ Set all entries of the mapping other in one transaction """
        cursor = self.get_cursor()
        try:
            if not self._conn.in_transaction:
                cursor.begin(EXCLUSIVE)
            for (key, value) in other.items():
                self.__setitem_nocommit(cursor, key, value)
            if commit:
                conn_commit(self._conn)
        finally:
            cursor.close()

    def clear(self, args = None, commit=True):
        """ Overridden clear that allows removing several keys atomically

        With args=None, all entries are removed. Otherwise args is an
        iterable of keys to remove; if any of them is missing, KeyError is
        raised before anything is modified. """
        if args is not None:
            # Materialize (we iterate twice) and drop duplicate keys
            args = list(dict.fromkeys(args))
            missing = [key for key in args if key not in self._data]
            if missing:
                raise KeyError(missing[0])
        cursor = self.get_cursor()
        try:
            if not self._conn.in_transaction:
                cursor.begin(EXCLUSIVE)
            if args is None:
                self._data.clear()
                self._table.delete(cursor)
            else:
                for key in args:
                    del self._data[key]
                    self._table.delete(cursor, eq={"kkey": key})
            if commit:
                conn_commit(self._conn)
        finally:
            cursor.close()
1003
1004
class Mapper(object):
    """ This class translates between application objects, i.e., Python
        dictionaries, and the relational data layout in an SQL DB, i.e.,
        one or more tables which possibly have foreign key relationships
        that map to hierarchical data structures. For now, only one
        foreign key / sub-dictionary.

        A row of the main table maps to a dictionary; for each subtable
        name s, the entry s of that dictionary is None or a list of
        dictionaries, one per matching row of the subtable."""

    def __init__(self, table, subtables = None):
        """ table is the main table; subtables optionally maps names to
        tables whose rows reference the main table's primary key """
        self.table = table
        self.subtables = {}
        if subtables:
            for s in subtables.keys():
                self.subtables[s] = Mapper(subtables[s])

    def __sub_dict(self, d):
        """ For each key "k" that has a subtable assigned in "self.subtables",
        pop the entry with key "k" from "d", and store it in a new dictionary
        which is returned. I.e., the dictionary d is separated into
        two parts: the part which corresponds to subtables and is the return
        value, and the rest which is left in the input dictionary. """
        sub_dict = {}
        for s in self.subtables.keys():
            # Don't store s:None entries even if they exist in d
            t = d.pop(s, None)
            if t is not None:
                sub_dict[s] = t
        return sub_dict

    def getname(self):
        return self.table.getname()

    def getpk(self):
        return self.table.getpk()

    def create(self, cursor):
        """ Create the main table and all subtables """
        self.table.create(cursor)
        for t in self.subtables.values():
            t.create(cursor)

    def insert(self, cursor, wus, foreign=None):
        """ Insert the dictionaries in wus, including their subtable entries.

        The primary key assigned to each new main-table row is copied into
        the caller's dictionary. """
        pk = self.getpk()
        for wu in wus:
            # Make copy so sub_dict does not change caller's data
            wuc = wu.copy()
            # Split off entries that refer to subtables
            sub_dict = self.__sub_dict(wuc)
            # We add the entries in wuc only if it does not have a primary
            # key yet. If it does have a primary key, we add only the data
            # for the subtables
            if pk not in wuc:
                self.table.insert(cursor, wuc, foreign=foreign)
                # Copy primary key into caller's data
                wu[pk] = wuc[pk]
            for subtable_name in sub_dict.keys():
                self.subtables[subtable_name].insert(
                    cursor, sub_dict[subtable_name], foreign=wu[pk])

    def update(self, cursor, wus):
        """ Update existing rows from the dictionaries in wus.

        Each dictionary must contain this table's primary key, which selects
        the row to update; its remaining non-subtable entries are written to
        that row. Subtable entries must be lists of dictionaries that each
        contain the subtable's primary key; they are updated recursively.
        The caller's dictionaries are not modified. """
        pk = self.getpk()
        for wu in wus:
            assert wu[pk] is not None
            wuc = wu.copy()
            sub_dict = self.__sub_dict(wuc)
            rowid = wuc.pop(pk)
            # Only issue an UPDATE when there is a column to write
            if wuc:
                self.table.update(cursor, wuc, eq={pk: rowid})
            for s, subrows in sub_dict.items():
                self.subtables[s].update(cursor, subrows)

    def count(self, cursor, **cond):
        """ Count rows of the main table matching cond """
        joinsource = self.table.tablename
        return cursor.count(joinsource, **cond)

    def where(self, cursor, limit = None, order = None, **cond):
        """ Select main-table rows matching cond, joined with the subtables.

        Returns a list of dictionaries; limit applies to main-table rows. """
        # We want:
        # SELECT * FROM (SELECT * from workunits WHERE status = 2 LIMIT 1) LEFT JOIN files USING ( wurowid );
        pk = self.getpk()
        (command, values) = cursor.where_query(self.table.tablename,
                                               limit=limit, **cond)
        joinsource = "( %s )" % command
        for s in self.subtables.keys():
            # FIXME: this probably breaks with more than 2 tables
            joinsource = "%s tmp LEFT JOIN %s USING ( %s )" \
                         % (joinsource, self.subtables[s].getname(), pk)
        # FIXME: don't get result rows as dict! Leave as tuple and
        # take them apart positionally

        rows = cursor.where_as_dict(joinsource, order=order, values=values)
        wus = []
        for r in rows:

            # Collapse rows with identical primary key
            if len(wus) == 0 or r[pk] != wus[-1][pk]:
                wus.append(self.table.dictextract(r))
                for s in self.subtables.keys():
                    wus[-1][s] = None

            for (sn, sm) in self.subtables.items():
                spk = sm.getpk()
                # if there was a match on the subtable
                if spk in r and r[spk] is not None:
                    if wus[-1][sn] is None:
                        # If this sub-array is empty, init it
                        wus[-1][sn] = [sm.table.dictextract(r)]
                    elif r[spk] != wus[-1][sn][-1][spk]:
                        # If not empty, and primary key of sub-table is not
                        # same as in previous entry, add it
                        wus[-1][sn].append(sm.table.dictextract(r))
        return wus
1114
class WuAccess(object): # {
    """ This class maps between the WORKUNIT and FILES tables
        and a dictionary
        {"wuid": string, ..., "timeverified": string, "files": list}
        where list is None or a list of dictionaries holding the columns
        of the files table, i.e., of the form
        {"filesrowid": int, "filename": string, "path": string,
        "type": string, "command": int or None}
        Operations on instances of WuAccess are directly carried
        out on the database persistent storage, i.e., they behave kind
        of as if the WuAccess instance were itself a persistent
        storage device """

    def __init__(self, db):
        """ db is a DBFactory (a connection is opened and owned by this
        object) or an already-open connection (used as-is). """
        if isinstance(db, DBFactory):
            self.conn = db.connect()
            self._ownconn = True
        elif isinstance(db, str):
            raise ValueError("unexpected")
        else:
            self.conn = db
            self._ownconn = False
        cursor = self.get_cursor()
        if isinstance(cursor, DB_SQLite.CursorWrapper):
            cursor.pragma("foreign_keys = ON")
        # I'm not sure it's relevant to do commit() at this point.
        # self.commit()
        cursor.close()
        self.mapper = Mapper(WuTable(), {"files": FilesTable()})

    def get_cursor(self):
        c = self.conn.cursor()
        return c

    def __del__(self):
        """ Close the DB connection if this object opened it """
        # _ownconn does not exist if __init__ raised before setting it
        if getattr(self, "_ownconn", False):
            if callable(conn_close):
                conn_close(self.conn)
            else:
                self.conn.close()

    @staticmethod
    def to_str(wus):
        """ Human-readable multi-line rendering of a list of workunit dicts """
        r = []
        for wu in wus:
            s = "Workunit %s:\n" % wu["wuid"]
            for (k,v) in wu.items():
                if k != "wuid" and k != "files":
                    s += "  %s: %r\n" % (k, v)
            if "files" in wu:
                s += "  Associated files:\n"
                if wu["files"] is None:
                    s += "    None\n"
                else:
                    for f in wu["files"]:
                        s += "    %s\n" % f
            r.append(s)
        return '\n'.join(r)

    @staticmethod
    def _checkstatus(wu, status):
        """ Raise StatusUpdateError unless wu["status"] equals status (or,
        if status is a container, is one of its elements). """
        #logger.debug("WuAccess._checkstatus(%s, %s)", wu, status)
        wu_status = wu["status"]
        if isinstance(status, collections.abc.Container):
            ok = wu_status in status
        else:
            ok = wu_status == status
        if ok:
            return
        msg = "Workunit %s has status %s (%s), expected %s (%s)" % \
              (wu["wuid"], wu_status, WuStatus.get_name(wu_status),
               status, WuStatus.get_name(status))
        # Compare by value: statuses are plain values read from the DB, so
        # identity ("is") comparisons are not reliable.
        if status == WuStatus.ASSIGNED and wu_status == WuStatus.CANCELLED:
            logger.warning ("WuAccess._checkstatus(): %s, presumably timed out", msg)
        elif status == WuStatus.ASSIGNED and wu_status == WuStatus.NEED_RESUBMIT:
            logger.warning ("WuAccess._checkstatus(): %s, manually expired", msg)
        else:
            logger.error ("WuAccess._checkstatus(): %s", msg)
        raise StatusUpdateError(msg)

    # Which fields should be None for which status
    should_be_unset = {
        "errorcode": (WuStatus.AVAILABLE, WuStatus.ASSIGNED),
        "timeresult": (WuStatus.AVAILABLE, WuStatus.ASSIGNED),
        "resultclient": (WuStatus.AVAILABLE, WuStatus.ASSIGNED),
        "timeassigned": (WuStatus.AVAILABLE,),
        "assignedclient": (WuStatus.AVAILABLE,),
    }
    def check(self, data):
        """ Assert internal consistency of a workunit record """
        status = data["status"]
        WuStatus.check(status)
        wu = Workunit(data["wu"])
        assert wu.get_id() == data["wuid"]
        if status == WuStatus.RECEIVED_ERROR:
            assert data["errorcode"] != 0
        if status == WuStatus.RECEIVED_OK:
            assert data["errorcode"] is None or data["errorcode"] == 0
        for field in self.should_be_unset:
            if status in self.should_be_unset[field]:
                assert data[field] is None

    # Here come the application-visible functions that implement the
    # "business logic": creating a new workunit from the text of a WU file,
    # assigning it to a client, receiving a result for the WU, marking it as
    # verified, or marking it as cancelled

    def _add_files(self, cursor, files, wuid=None, rowid=None):
        """ Add rows to the files table for one workunit.

        The workunit is selected either by wuid or by its primary key
        rowid; exactly one of the two must be given. Each element of files
        is a sequence (filename, path, type[, command]).

        Returns False if wuid was given but no such workunit exists,
        True otherwise. """
        # Exactly one must be given
        assert wuid is not None or rowid is not None
        assert wuid is None or rowid is None
        # FIXME: allow selecting row to update directly via wuid, without
        # doing query for rowid first
        pk = self.mapper.getpk()
        if rowid is None:
            wu = self.get_by_wuid(cursor, wuid)
            if not wu:
                return False
            rowid = wu[pk]
        colnames = ("filename", "path", "type", "command")
        # zipped length is that of shortest list, so "command" is optional
        d = (dict(zip(colnames, f)) for f in files)
        # Passing the primary key makes the mapper skip the workunits table
        # and insert only the "files" subtable rows, with foreign=rowid
        self.mapper.insert(cursor, [{pk: rowid, "files": d}])
        return True

    def commit(self, do_commit=True):
        if do_commit:
            conn_commit(self.conn)

    def create_tables(self):
        cursor = self.get_cursor()
        if isinstance(cursor, DB_SQLite.CursorWrapper):
            cursor.pragma("journal_mode=WAL")
        self.mapper.create(cursor)
        self.commit()
        cursor.close()

    def _create1(self, cursor, wutext, priority=None):
        """ Insert one new AVAILABLE workunit from its text """
        d = {
            "wuid": Workunit(wutext).get_id(),
            "wu": wutext,
            "status": WuStatus.AVAILABLE,
            "timecreated": str(datetime.utcnow())
            }
        if priority is not None:
            d["priority"] = priority
        # Insert directly into wu table
        self.mapper.table.insert(cursor, d)

    def create(self, wus, priority=None, commit=True):
        """ Create new workunits from wus which contains the texts of the
            workunit files """
        cursor = self.get_cursor()
        # todo restore transactions
        if not self.conn.in_transaction:
            cursor.begin(EXCLUSIVE)
        if isinstance(wus, str):
            self._create1(cursor, wus, priority)
        else:
            for wu in wus:
                self._create1(cursor, wu, priority)
        self.commit(commit)
        cursor.close()

    def assign(self, clientid, commit=True, timeout_hint=None):
        """ Finds an available workunit and assigns it to clientid.
            Returns the text of the workunit, or None if no available
            workunit exists """
        cursor = self.get_cursor()
        if not self.conn.in_transaction:
            cursor.begin(EXCLUSIVE)
        # This "priority" stuff is the root cause for the server taking time to
        # hand out WUs when the count of available WUs drops to zero.
        # (introduced in 90ae4beb7 -- it's an optional-and-never-used feature
        # anyway)
        #        r = self.mapper.table.where(cursor, limit = 1,
        #                                    order=("priority", "DESC"),
        #                                    eq={"status": WuStatus.AVAILABLE})
        r = self.mapper.table.where(cursor, limit = 1,
                                    eq={"status": WuStatus.AVAILABLE})
        assert len(r) <= 1
        if len(r) == 1:
            try:
                self._checkstatus(r[0], WuStatus.AVAILABLE)
            except StatusUpdateError:
                self.commit(commit)
                cursor.close()
                raise
            if DEBUG > 0:
                self.check(r[0])
            d = {"status": WuStatus.ASSIGNED,
                 "assignedclient": clientid,
                 "timeassigned": str(datetime.utcnow())
                 }
            pk = self.mapper.getpk()
            self.mapper.table.update(cursor, d, eq={pk:r[0][pk]})
            result = r[0]["wu"]
            if timeout_hint:
                dltext = "%d\n" % int(time.time() + int(timeout_hint))
                result = result + "DEADLINE " + dltext

        else:
            result = None

        self.commit(commit)

        cursor.close()
        return result

    def get_by_wuid(self, cursor, wuid):
        """ Return the workunit record with this wuid, or None """
        r = self.mapper.where(cursor, eq={"wuid": wuid})
        assert len(r) <= 1
        if len(r) == 1:
            return r[0]
        else:
            return None

    def result(self, wuid, clientid, files, errorcode=None,
               failedcommand=None, commit=True):
        """ Record the result of an ASSIGNED workunit.

        Returns False if no workunit with this wuid exists, True on
        success; raises StatusUpdateError if the workunit is not ASSIGNED.
        """
        global PRINTED_CANCELLED_WARNING
        cursor = self.get_cursor()
        if not self.conn.in_transaction:
            cursor.begin(EXCLUSIVE)
        data = self.get_by_wuid(cursor, wuid)
        if data is None:
            self.commit(commit)
            cursor.close()
            return False
        try:
            self._checkstatus(data, WuStatus.ASSIGNED)
        except StatusUpdateError:
            self.commit(commit)
            cursor.close()
            if data["status"] == WuStatus.CANCELLED:
                if not PRINTED_CANCELLED_WARNING:
                    logger.warning("If workunits get cancelled due to timeout "
                            "even though the clients are still processing them, "
                            "consider increasing the tasks.wutimeout parameter or "
                            "decreasing the range covered in each workunit, "
                            "i.e., the tasks.polyselect.adrange or "
                            "tasks.sieve.qrange parameters.")
                    PRINTED_CANCELLED_WARNING = True
            raise
        if DEBUG > 0:
            self.check(data)
        d = {"resultclient": clientid,
             "errorcode": errorcode,
             "failedcommand": failedcommand,
             "timeresult": str(datetime.utcnow())}
        if errorcode is None or errorcode == 0:
            d["status"] = WuStatus.RECEIVED_OK
        else:
            d["status"] = WuStatus.RECEIVED_ERROR
        pk = self.mapper.getpk()
        self._add_files(cursor, files, rowid = data[pk])
        self.mapper.table.update(cursor, d, eq={pk:data[pk]})
        self.commit(commit)
        cursor.close()
        return True

    def verification(self, wuid, ok, commit=True):
        """ Mark a RECEIVED_OK/RECEIVED_ERROR workunit as verified.

        Returns False if no workunit with this wuid exists, True on
        success; raises StatusUpdateError on a wrong status. """
        cursor = self.get_cursor()
        if not self.conn.in_transaction:
            cursor.begin(EXCLUSIVE)
        data = self.get_by_wuid(cursor, wuid)
        if data is None:
            self.commit(commit)
            cursor.close()
            return False
        # FIXME: should we do the update by wuid and skip these checks?
        try:
            self._checkstatus(data, [WuStatus.RECEIVED_OK, WuStatus.RECEIVED_ERROR])
        except StatusUpdateError:
            self.commit(commit)
            cursor.close()
            raise
        if DEBUG > 0:
            self.check(data)
        d = {"timeverified": str(datetime.utcnow())}
        d["status"] = WuStatus.VERIFIED_OK if ok else WuStatus.VERIFIED_ERROR
        pk = self.mapper.getpk()
        self.mapper.table.update(cursor, d, eq={pk:data[pk]})
        self.commit(commit)

        cursor.close()
        return True

    def cancel(self, wuid, commit=True):
        self.cancel_by_condition(eq={"wuid": wuid}, commit=commit)

    def cancel_all_available(self, commit=True):
        self.cancel_by_condition(eq={"status": WuStatus.AVAILABLE}, commit=commit)

    def cancel_all_assigned(self, commit=True):
        self.cancel_by_condition(eq={"status": WuStatus.ASSIGNED}, commit=commit)

    def cancel_by_condition(self, commit=True, **conditions):
        self.set_status(WuStatus.CANCELLED, commit=commit, **conditions)

    def set_status(self, status, commit=True, **conditions):
        """ Set the status of all workunits matching conditions """
        cursor = self.get_cursor()
        if not self.conn.in_transaction:
            cursor.begin(EXCLUSIVE)
        self.mapper.table.update(cursor, {"status": status}, **conditions)
        self.commit(commit)
        cursor.close()

    def query(self, limit=None, **conditions):
        cursor = self.get_cursor()
        r = self.mapper.where(cursor, limit=limit, **conditions)
        cursor.close()
        return r

    def count(self, **cond):
        cursor = self.get_cursor()
        count = self.mapper.count(cursor, **cond)
        cursor.close()
        return count

    def count_available(self):
        return self.count(eq={"status": WuStatus.AVAILABLE})

    def get_one_result(self):
        """ Return one workunit with a received result (RECEIVED_OK is
        preferred over RECEIVED_ERROR), or None if there is none """
        r = self.query(limit = 1, eq={"status": WuStatus.RECEIVED_OK})
        if not r:
            r = self.query(limit = 1, eq={"status": WuStatus.RECEIVED_ERROR})
        if not r:
            return None
        else:
            return r[0]
#}
1449
class WuResultMessage(metaclass=abc.ABCMeta):
    """ Abstract interface to the result of a processed workunit.

    command_nr arguments select the COMMAND of the workunit. Concrete
    subclasses supply either the captured output directly (get_stdout /
    get_stderr) or the path of a file holding it (get_stdoutfile /
    get_stderrfile); read_stdout() / read_stderr() combine the two. """

    @abc.abstractmethod
    def get_wu_id(self):
        """ Return the id of the workunit this result belongs to """

    @abc.abstractmethod
    def get_output_files(self):
        """ Return the list of output files """

    @abc.abstractmethod
    def get_stdout(self, command_nr):
        """ Return captured stdout data, or None """

    @abc.abstractmethod
    def get_stdoutfile(self, command_nr):
        """ Return the path of a file holding stdout, or None """

    @abc.abstractmethod
    def get_stderr(self, command_nr):
        """ Return captured stderr data, or None """

    @abc.abstractmethod
    def get_stderrfile(self, command_nr):
        """ Return the path of a file holding stderr, or None """

    @abc.abstractmethod
    def get_exitcode(self, command_nr):
        """ Return the exit code of the command """

    @abc.abstractmethod
    def get_command_line(self, command_nr):
        """ Return the command line of the command """

    @abc.abstractmethod
    def get_host(self):
        """ Return the client that produced this result """

    def _read(self, filename, data):
        """ Return file contents if filename is set, else data, else b"" """
        if filename is not None:
            # The file wins over any in-memory data
            with open(filename, "rb") as inputfile:
                return inputfile.read()
        if data is None:
            return bytes()
        return data

    def read_stdout(self, command_nr):
        """ Returns the contents of stdout of command_nr as a byte string.

        If no stdout was captured, returns the empty byte string.
        """
        path = self.get_stdoutfile(command_nr)
        captured = self.get_stdout(command_nr)
        return self._read(path, captured)

    def read_stderr(self, command_nr):
        """ Like read_stdout() but for stderr """
        path = self.get_stderrfile(command_nr)
        captured = self.get_stderr(command_nr)
        return self._read(path, captured)
1494
1495
class ResultInfo(WuResultMessage):
    """ WuResultMessage backed by a workunit record dictionary as returned
    by WuAccess (e.g. get_one_result()). """

    def __init__(self, record):
        # record looks like this:
        # {'status': 0, 'errorcode': None, 'timeresult': None, 'wuid': 'testrun_polyselect_0-5000',
        #  'wurowid': 1, 'timecreated': '2013-05-23 22:28:08.333310', 'timeverified': None,
        #  'failedcommand': None, 'priority': None, 'wu': "WORKUNIT [..rest of workunit text...] \n",
        #  'assignedclient': None, 'retryof': None, 'timeassigned': None, 'resultclient': None,
        #  'files': None}
        self.record = record

    def __str__(self):
        return str(self.record)

    def get_wu_id(self):
        return self.record["wuid"]

    def get_output_files(self):
        """ Returns the list of output files of this workunit

        Only files that were specified in RESULT lines appear here;
        automatically captured stdout and stderr does not.
        """
        if self.record["files"] is None:
            return []
        files = []
        for f in self.record["files"]:
            if f["type"] == "RESULT":
                files.append(f["path"])
        return files

    def _get_stdio(self, filetype, command_nr):
        """ Get the file location of the stdout or stderr file of the
        command_nr-th command. Used internally.
        """
        if self.record["files"] is None:
            return None
        for f in self.record["files"]:
            # The command column is optional when files are recorded and
            # may be None; skip such rows instead of failing in int()
            if f["type"] == filetype and f["command"] is not None \
                    and int(f["command"]) == command_nr:
                return f["path"]
        return None

    def get_stdout(self, command_nr):
        # stdout is always captured into a file, not made available directly
        return None

    def get_stdoutfile(self, command_nr):
        """ Return the path to the file that captured stdout of the
        command_nr-th COMMAND in the workunit, or None if there was no stdout
        output. Note that explicitly redirected stdout that was uploaded via
        RESULT does not appear here, but in get_output_files()
        """
        return self._get_stdio("stdout", command_nr)

    def get_stderr(self, command_nr):
        # stderr is always captured into a file, not made available directly
        return None

    def get_stderrfile(self, command_nr):
        """ Like get_stdoutfile(), but for stderr """
        return self._get_stdio("stderr", command_nr)

    def get_exitcode(self, command_nr):
        """ Return the exit code of the command_nr-th command

        Only the failed command (record["failedcommand"]) has a non-zero
        exit code recorded; every other command reports 0. """
        if self.record["failedcommand"] is not None \
                and command_nr == int(self.record["failedcommand"]):
            return int(self.record["errorcode"])
        else:
            return 0

    def get_command_line(self, command_nr):
        # Command lines are not stored in the record
        return None

    def get_host(self):
        return self.record["resultclient"]
1569
1570
class DbListener(patterns.Observable):
    """ Observable that fetches workunit results from the Workunit database
    and forwards each one, wrapped in a ResultInfo, to its Observers.

    Results are fetched one at a time by send_result(). A query triggered
    by SIGUSR1 through the signal handler relay was planned as well, but is
    not implemented (see FIXME below).
    """
    # FIXME: SIGUSR1 handler is not implemented
    def __init__(self, *args, db, **kwargs):
        super().__init__(*args, **kwargs)
        self.wuar = WuAccess(db)

    def send_result(self):
        """ Fetch one result and notify the Observers about it.

        Returns False when get_one_result() yields nothing. Otherwise returns
        the value of notifyObservers(); if that value is falsy, an error is
        logged and the workunit is cancelled.
        """
        row = self.wuar.get_one_result()
        if not row:
            return False
        message = ResultInfo(row)
        was_received = self.notifyObservers(message)
        if was_received:
            return was_received
        wuid = message.get_wu_id()
        logger.error("Result for workunit %s was not processed by any task. "
                     "Setting it to status CANCELLED", wuid)
        self.wuar.cancel(wuid)
        return was_received
1595
class IdMap(object):
    """ Identity map for DB-backed dictionaries: hands out one shared
    DictDbAccess instance per table name.

    Caveat: the lookup key is the table name alone. Ideally the database
    would be part of the key too, but a file name does not uniquely identify
    a database file, and callers may pass connection objects instead of a
    file name. The sqlite3 module API offers no way to tell whether two such
    handles refer to the same database.
    """
    def __init__(self):
        # table name -> DictDbAccess instance
        self.db_dicts = {}

    def make_db_dict(self, db, name):
        """ Return the cached dictionary for table `name`, creating it from
        `db` on the first request. On later calls for the same name, `db`
        is ignored. """
        try:
            return self.db_dicts[name]
        except KeyError:
            pass
        created = DictDbAccess(db, name)
        self.db_dicts[name] = created
        return created


# Module-wide singleton instance of IdMap
idmap = IdMap()
1616
class DbAccess(object):
    """ Cooperative base class that gives subclasses access to a database.

    __init__ consumes the keyword-only `db` argument and passes all other
    arguments on to the next class in the MRO. `db` is a database object
    offering connect(), `path` and `uri`. It is kept in a private attribute
    and used to open connections and to build DB-backed dictionaries,
    WuAccess instances and DbListener instances.
    """

    def __init__(self, *args, db, **kwargs):
        super().__init__(*args, **kwargs)
        self.__db = db

    def _pick_db(self, connection):
        """ Return `connection` if one was given, else the stored db object. """
        return self.__db if connection is None else connection

    def get_db_connection(self):
        """ Return the result of connect() on the stored db object. """
        return self.__db.connect()

    def get_db_filename(self):
        """ Return the `path` attribute of the stored db object. """
        return self.__db.path

    def get_db_uri(self):
        """ Return the `uri` attribute of the stored db object. """
        return self.__db.uri

    def make_db_dict(self, name, connection=None):
        """ Return the shared DB-backed dictionary for table `name` from the
        module-level identity map, built on `connection` if given, else on
        the stored db object. """
        return idmap.make_db_dict(self._pick_db(connection), name)

    def make_wu_access(self, connection=None):
        """ Return a new WuAccess on `connection` or the stored db object. """
        return WuAccess(self._pick_db(connection))

    def make_db_listener(self, connection=None):
        """ Return a new DbListener on `connection` or the stored db object. """
        return DbListener(db=self._pick_db(connection))
1656
1657
class HasDbConnection(DbAccess):
    """ Mixin that opens a database connection at construction time (via
    get_db_connection()) and stores it in the `db_connection` attribute.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        connection = self.get_db_connection()
        self.db_connection = connection
1665
1666
class UsesWorkunitDb(HasDbConnection):
    """ Mixin that provides a `wuar` attribute: a WuAccess instance built on
    the connection already stored in self.db_connection.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        shared = self.db_connection
        self.wuar = self.make_wu_access(connection=shared)
1674
1675
class DbWorker(DbAccess, threading.Thread):
    """ Thread executing WuAccess requests taken from a task queue.

    The thread starts itself at the end of __init__. If `daemon` is not
    None, it sets the thread's daemon flag before starting.
    """
    def __init__(self, taskqueue, *args, daemon=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.taskqueue = taskqueue
        if daemon is not None:
            self.daemon = daemon
        self.start()

    def run(self):
        """ Process queued requests until a "terminate" request arrives.

        Each queue item is a 4-tuple (result_tuple, fn_name, args, kargs):
          result_tuple: a 2-element list. Element [0] receives the return
            value of the call. Element [1] is an Event that is set once the
            call has finished.
          fn_name: name (str) of the WuAccess method to call, or "terminate"
            to stop this worker
          args: tuple of positional arguments
          kargs: dict of keyword arguments
        If the call raises, the traceback is printed and result_tuple[0] is
        left unchanged. The Event is still set, so the caller does not block
        forever.
        """
        # One DB connection per thread, created inside the new thread:
        # sqlite3 connections by default refuse to be used from a thread
        # other than the one that created them.
        wuar = self.make_wu_access()
        while True:
            (result_tuple, fn_name, args, kargs) = self.taskqueue.get()
            if fn_name == "terminate":
                # The sentinel is a queue item too. Without task_done() here,
                # Queue.join() would block forever once terminate requests
                # have been queued.
                self.taskqueue.task_done()
                break
            ev = result_tuple[1]
            try:
                # Store into the caller's list in place (no copy), so that
                # the result is visible to the thread waiting on ev.
                result_tuple[0] = getattr(wuar, fn_name)(*args, **kargs)
            except Exception:
                traceback.print_exc()
            ev.set()
            self.taskqueue.task_done()
1708
class DbRequest(object):
    """ A pending call to one named WuAccess method.

    DbThreadPool.__getattr__ creates one of these per attribute lookup and
    returns its bound do_task method. Calling that method sends the method
    name and arguments through the task queue to a worker thread.
    """
    def __init__(self, taskqueue, func):
        self.taskqueue, self.func = taskqueue, func

    def do_task(self, *args, **kargs):
        """ Queue a call of self.func with the given arguments, block until
        a worker signals completion, and return the stored result. """
        done = threading.Event()
        slot = [None, done]
        self.taskqueue.put((slot, self.func, args, kargs))
        done.wait()
        return slot[0]
1725
class DbThreadPool(object):
    """ Pool of DbWorker threads consuming WuAccess requests from a shared
    queue. Calls of WuAccess method names on the pool are forwarded to a
    worker thread (see __getattr__).
    """
    def __init__(self, dburi, num_threads=1):
        # Bounded queue: put() blocks while num_threads requests are pending
        self.taskqueue = Queue(num_threads)
        self.pool = []
        for _ in range(num_threads):
            worker = DbWorker(self.taskqueue, daemon=True, db=dburi)
            self.pool.append(worker)

    def terminate(self):
        """ Ask every worker thread to exit, and wait until all have exited.

        One "terminate" sentinel is queued per worker. Each worker stops
        after taking one sentinel, and the queue is FIFO, so requests queued
        before this call are still handled.
        """
        for _ in self.pool:
            self.taskqueue.put((None, "terminate", None, None))
        # Join the threads rather than the queue. Queue.join() only returns
        # after task_done() was called for every item, including the
        # sentinels. Thread.join() simply waits for each worker to exit.
        for worker in self.pool:
            worker.join()

    def wait_completion(self):
        """Wait for completion of all the tasks in the queue"""
        self.taskqueue.join()

    def __getattr__(self, name):
        """ Delegate calls to methods of WuAccess to a worker thread.

        Only invoked for attributes not found normally. If WuAccess has an
        attribute `name`, this creates a DbRequest that remembers the name
        and returns its do_task method, which sends the call through the
        thread pool. The extra object is needed because the caller has no
        other way to tell the pool which method to run. Any other name
        raises AttributeError.
        """
        if hasattr(WuAccess, name):
            task = DbRequest(self.taskqueue, name)
            return task.do_task
        else:
            raise AttributeError(name)
1758
1759
# One entry in the WU DB holds the text of the workunit (its FILEs,
# COMMANDs, etc.) together with information about its progress (when and
# to whom it was assigned, when its result was received, etc.). Fields:
#
#   wuid is the unique id of the workunit
#   status is a status code as defined in WuStatus
#   data is the str containing the text of the workunit
#   timecreated is the string with the date and time when the WU was added to the DB
#   timeassigned is the string with the date and time when the WU was assigned to a client
#   assignedclient is the clientid of the client to which the WU was assigned
#   timeresult is the string with the date and time when a result for this WU was received
#   resultclient is the clientid of the client that uploaded a result for this WU
#   errorcode is the exit status code of the first failed command, or 0 if none failed
#   timeverified is the string with the date and time when the result was marked as verified
1774
1775
def _split_workunits(lines):
    """ Split an iterable of text lines into workunit texts.

    Workunits are separated by lines consisting only of "\\n". Empty
    chunks (from leading or repeated separator lines) are dropped, since
    they contain no workunit text.
    """
    wus = []
    current = []
    for line in lines:
        if line == "\n":
            if current:
                wus.append("".join(current))
            current = []
        else:
            current.append(line)
    if current:
        wus.append("".join(current))
    return wus


if __name__ == '__main__': # {
    import argparse
    import builtins

    queries = {"avail" : ("Available workunits", {"eq": {"status": WuStatus.AVAILABLE}}),
               "assigned": ("Assigned workunits", {"eq": {"status": WuStatus.ASSIGNED}}),
               "receivedok": ("Received ok workunits", {"eq":{"status": WuStatus.RECEIVED_OK}}),
               "receivederr": ("Received with error workunits", {"eq": {"status": WuStatus.RECEIVED_ERROR}}),
               "verifiedok": ("Verified ok workunits", {"eq": {"status": WuStatus.VERIFIED_OK}}),
               "verifiederr": ("Verified with error workunits", {"eq": {"status": WuStatus.VERIFIED_ERROR}}),
               "cancelled": ("Cancelled workunits", {"eq": {"status": WuStatus.CANCELLED}}),
               "all": ("All existing workunits", {})
              }

    use_pool = False

    parser = argparse.ArgumentParser()
    parser.add_argument('-dbfile', help='Name of the database file')
    parser.add_argument('-create', action="store_true",
                        help='Create the database tables if they do not exist')
    parser.add_argument('-add', action="store_true",
                        help='Add new workunits. Contents of WU(s) are '
                        'read from stdin, separated by blank line')
    parser.add_argument('-assign', nargs = 1, metavar = 'clientid',
                        help = 'Assign an available WU to clientid')
    parser.add_argument('-cancel', action="store_true",
                        help = 'Cancel selected WUs')
    parser.add_argument('-expire', action="store_true",
                        help = 'Expire selected WUs')
    # parser.add_argument('-setstatus', metavar = 'STATUS',
    #                    help = 'Forcibly set selected workunits to status (integer)')
    parser.add_argument('-prio', metavar = 'N',
                        help = 'If used with -add, newly added WUs '
                        'receive priority N')
    parser.add_argument('-limit', metavar = 'N',
                        help = 'Limit number of records in queries',
                        default = None)
    parser.add_argument('-result', nargs = 6,
                        metavar = ('wuid', 'clientid', 'filename', 'filepath',
                                   'filetype', 'command'),
                        help = 'Return a result for wu from client')
    parser.add_argument('-test', action="store_true",
                        help='Run some self tests')
    parser.add_argument('-debug', help='Set debugging level')
    parser.add_argument('-setdict', nargs = 4,
                        metavar = ("dictname", "keyname", "type", "keyvalue"),
                        help='Set an entry of a DB-backed dictionary')

    parser.add_argument('-wuid', help="Select workunit with given id",
                        metavar="WUID")
    for arg in queries:
        parser.add_argument('-' + arg, action="store_true", required=False,
                            help="Select %s" % queries[arg][0].lower())
    parser.add_argument('-dump', nargs='?', default = None, const = "all",
                        metavar = "FIELD",
                        help='Dump WU contents, optionally a single field')
    parser.add_argument('-sort', metavar = "FIELD",
                        help='With -dump, sort output by FIELD')
    # Parse command line, store as dictionary
    args = vars(parser.parse_args())

    dbname = "wudb"
    if args["dbfile"]:
        dbname = args["dbfile"]

    if args["test"]:
        import doctest
        doctest.testmod()

    if args["debug"]:
        DEBUG = int(args["debug"])
    prio = 0
    if args["prio"]:
        # -prio takes a single string (no nargs). Convert the whole string;
        # indexing [0] would keep only its first character.
        prio = int(args["prio"])
    limit = args["limit"]

    db = DBFactory('db:sqlite3://%s' % dbname)

    if use_pool:
        db_pool = DbThreadPool(db)
    else:
        db_pool = WuAccess(db)

    if args["create"]:
        db_pool.create_tables()
    if args["add"]:
        wus = _split_workunits(sys.stdin)
        db_pool.create(wus, priority=prio)

    # Functions for queries
    queries_list = []
    for (arg, (msg, condition)) in queries.items():
        if args[arg]:
            queries_list.append([msg, condition])
    if args["wuid"]:
        for wuid in args["wuid"].split(","):
            msg = "Workunit %s" % wuid
            condition = {"eq": {"wuid": wuid}}
            queries_list.append([msg, condition])

    for (msg, condition) in queries_list:
        print("%s: " % msg)
        if not args["dump"]:
            count = db_pool.count(limit=limit, **condition)
            print(count)
        else:
            wus = db_pool.query(limit=limit, **condition)
            if wus is None:
                print("0")
            else:
                print(len(wus))
                if args["sort"]:
                    wus.sort(key=lambda wu: str(wu[args["sort"]]))
                if args["dump"] == "all":
                    print(WuAccess.to_str(wus))
                else:
                    for wu in wus:
                        print(wu[args["dump"]])
        if args["cancel"]:
            print("Cancelling selected workunits")
            db_pool.cancel_by_condition(**condition)
        if args["expire"]:
            print("Expiring selected workunits")
            db_pool.set_status(WuStatus.NEED_RESUBMIT, commit=True, **condition)
        # if args["setstatus"]:
        #    db_pool.set_status(int(args["setstatus"]), **condition)

    # Dict manipulation
    if args["setdict"]:
        (name, keyname, itemtype, keyvalue) = args["setdict"]
        # Type-cast the value using the named builtin (e.g. int, float, str).
        # NOTE: any builtin callable name is accepted here.
        value = getattr(builtins, itemtype)(keyvalue)
        # NOTE(review): DictDbAccess gets the DB file name here, while IdMap
        # passes a database object or connection; confirm both are accepted.
        dbdict = DictDbAccess(dbname, name)
        dbdict[keyname] = value
        del dbdict

    # Functions for testing
    if args["assign"]:
        clientid = args["assign"][0]
        db_pool.assign(clientid)

    if args["result"]:
        # With nargs=6, argparse stores a plain list:
        # [wuid, clientid, filename, filepath, filetype, command]
        result = args["result"]
        db_pool.result(result[0], result[1], result[2:])

    if use_pool:
        db_pool.terminate()
# }
1933
1934# Local Variables:
1935# version-control: t
1936# End:
1937