1import sqlite3 as sql
2import numpy as np
3import os
4import sys
5import warnings
6import pickle as pkl
7
# True only when running under a Python 2 interpreter.
LEGACY_PYTHON = sys.version_info < (3,)

# SQL skeletons. The first ``{}`` receives the table name, the second the
# per-column part; every table starts with an integer ``iteration`` key.
CREATE_TEMPLATE = "CREATE TABLE {} (iteration INTEGER PRIMARY KEY, {})"
INSERT_TEMPLATE = "INSERT INTO {} VALUES (?, {})"

# The interpreter's native raw-byte-string type.
byte_type = bytes if not LEGACY_PYTHON else str
14
def customize_create_template(colnames, tablename):
    """
    Build a CREATE TABLE statement for ``tablename`` with one BLOB column per name.

    The table always starts with the ``iteration INTEGER PRIMARY KEY`` column
    supplied by CREATE_TEMPLATE, followed by the columns of ``colnames`` in order.

    Arguments
    ---------
    colnames : iterable of str
        column names. Each is emitted as a double-quoted SQL identifier; any
        embedded double quote is doubled so the name cannot terminate its quotes
        early and corrupt the statement.
    tablename : str
        table name, interpolated verbatim (it is neither quoted nor escaped).

    Returns
    -------
    str : the CREATE TABLE statement.
    """
    columns = ' , '.join('"' + name.replace('"', '""') + '" BLOB'
                         for name in colnames)
    return CREATE_TEMPLATE.format(tablename, columns)
def customize_insert_template(colnames, tablename):
    """
    Build an INSERT statement for ``tablename`` with one ``?`` placeholder per
    entry of ``colnames``.

    INSERT_TEMPLATE already contributes the leading placeholder for the
    ``iteration`` key. Only the number of names matters; the names themselves
    are never used.
    """
    n_columns = sum(1 for _ in colnames)
    return INSERT_TEMPLATE.format(tablename, ' , '.join(['?'] * n_columns))
27
def start_sql(model, tracename='model_trace.db'):
    """
    Create a new SQLite database at ``tracename`` containing an empty ``trace`` table.

    The ``trace`` table has an ``iteration`` integer primary key followed by one
    BLOB column per entry of ``model.traced_params``.

    Arguments
    ---------
    model : object
        must expose ``traced_params``, an iterable of column names.
    tracename : str
        path of the database file to create.

    Returns
    -------
    (connection, cursor) : the open sqlite3 connection and a cursor on it.

    Raises
    ------
    Exception
        if ``tracename`` already exists; existing traces are never overwritten.
        Errors from creating the table are re-raised after cleanup.
    """
    if os.path.isfile(tracename):
        raise Exception('Will not overwrite existing trace {}'.format(tracename))
    cxn = sql.connect(tracename)
    try:
        cursor = cxn.cursor()
        cursor.execute(customize_create_template(model.traced_params, 'trace'))
    except Exception:
        # Release the handle and remove the file we just created. Otherwise it
        # leaks, and the leftover file makes every retry fail the
        # "will not overwrite" check above.
        cxn.close()
        if os.path.isfile(tracename):
            os.remove(tracename)
        raise
    return cxn, cursor
38
def head_to_sql(model, cursor, connection):
    """
    Write the newest trace entry (position -1) of ``model`` to the database.

    Shorthand for ``point_to_sql(model, cursor, connection, index=-1)``; that
    call also commits the insert.
    """
    newest = -1
    point_to_sql(model, cursor, connection, index=newest)
44
def point_to_sql(model, cursor, connection, index=0):
    """
    Insert the trace entry at position ``index`` into the ``trace`` table and commit.

    Arguments
    ---------
    model : object
        must expose ``cycles``, ``traced_params`` and ``trace``. Each parameter's
        value is read as ``model.trace[param, index]`` and passed through ``serialize``.
    cursor : sqlite3 cursor used to run the INSERT.
    connection : sqlite3 connection, committed after the insert.
    index : int
        trace position. A non-negative value is stored directly as the
        ``iteration`` key. A negative value counts back from the head:
        ``-1`` is stored as iteration ``model.cycles``, ``-2`` as
        ``model.cycles - 1``, and so on.
    """
    # The old formula ``cycles - (index + 1)`` mapped -1 to ``cycles`` but then
    # *increased* for older points (-2 -> cycles + 1). This keeps -1 -> cycles
    # and makes earlier points get earlier iteration numbers.
    iteration = model.cycles + index + 1 if index < 0 else index
    row = [iteration]
    row.extend(serialize(model.trace[param, index]) for param in model.traced_params)
    cursor.execute(customize_insert_template(model.traced_params, 'trace'), tuple(row))
    connection.commit()
58
def trace_to_sql(model, cursor, connection):
    """
    Insert every point of the model's trace into the ``trace`` table and commit once.

    For each ``i`` in ``range(model.cycles)``, a row is written with
    ``iteration = i`` followed by ``serialize(model.trace[param, i])`` for each
    name in ``model.traced_params``. The ``trace`` table must already exist
    (see ``start_sql``).
    """
    params = list(model.traced_params)
    # The INSERT statement depends only on the column count, so build it once
    # instead of once per row.
    statement = customize_insert_template(params, 'trace')
    rows = (tuple([i] + [serialize(model.trace[param, i]) for param in params])
            for i in range(model.cycles))
    cursor.executemany(statement, rows)
    connection.commit()
69
def trace_from_sql(filename, table='trace'):
    """
    Reconstruct a ``Trace`` from one table of a SQLite database.

    Column names are read with ``PRAGMA table_info``, and every stored value is
    passed through ``maybe_deserialize``. The ``iteration`` key column is included.

    Arguments
    ---------
    filename : str
        path of the SQLite database.
    table : str
        table to read. For ``'trace'``, each column becomes a list holding every
        row. For any other table (e.g. the single-row ``'state'`` table written
        by ``model_to_sql``), only the first row's value of each column is kept.

    Returns
    -------
    Trace (from ``.abstracts``) built from keyword arguments, one per column.
    """
    # NOTE(review): ``table`` is formatted into the SQL unquoted; never pass
    # untrusted input here.
    cxn = sql.connect(filename)
    try:
        pragma = cxn.execute('PRAGMA table_info({})'.format(table)).fetchall()
        data = cxn.execute('SELECT * FROM {}'.format(table)).fetchall()
    finally:
        # Close even when the table is missing or the SELECT fails.
        cxn.close()
    colnames = [info[1] for info in pragma]
    # Transpose rows into columns and pair each column with its name.
    records = zip(colnames, map(list, zip(*data)))

    # Import must occur here otherwise there's a circularity issue
    from .abstracts import Trace

    if table == 'trace':
        return Trace(**{colname: [maybe_deserialize(entry) for entry in column]
                        for colname, column in records})
    return Trace(**{colname: maybe_deserialize(column[0])
                    for colname, column in records})
93
def model_to_sql(model, cursor, connection):
    """
    Serialize an entire model into a sqlite database.

    Three tables are written:

    - ``trace``: every trace point, via ``trace_to_sql``. That table must already
      exist (e.g. from ``start_sql``), and this step commits on its own.
    - ``state``: created here; one row keyed by ``model.cycles``, with one column
      per name in ``model.state.varnames``.
    - ``model``: created here; one row holding the pickled model *class* (not the
      instance) in the ``model_class`` column.

    Trace and state values go through ``serialize``. Objects with a ``dumps()``
    method use it, ints and floats are stored as-is, and everything else is
    pickled with the standard ``pickle`` module (dill is not used when writing).
    Because the CREATE TABLE statements are unconditional, calling this twice on
    the same database fails.
    """
    trace_to_sql(model, cursor, connection)

    state_keys = list(model.state.varnames)
    cursor.execute(customize_create_template(state_keys, 'state'))
    state_row = [model.cycles]
    state_row.extend(serialize(model.state[key]) for key in state_keys)
    cursor.execute(customize_insert_template(state_keys, 'state'), tuple(state_row))

    # Pickle the class itself, not the instance. Pickle stores classes by
    # reference, so the class must be importable when the file is loaded.
    class_pkl = pkl.dumps(model.__class__)
    cursor.execute(customize_create_template(['model_class'], 'model'))
    # A NULL iteration lets SQLite assign the INTEGER PRIMARY KEY automatically.
    cursor.execute(customize_insert_template(['model_class'], 'model'), (None, class_pkl))
    connection.commit()
113
def model_from_sql(filename):
    """
    Reconstruct a model from a sqlite database with ``trace``, ``state`` and
    ``model`` tables (as written by ``model_to_sql``).

    If deserialization fails for the trace or state, the resulting trace/state
    may contain raw binary strings. If the ``model`` table is missing or cannot
    be unpickled, this function raises. To extract only the trace or the state,
    use ``trace_from_sql``.

    Returns
    -------
    The model instance ``model_class(**state)`` on success. If construction
    fails, a warning is issued and the tuple ``(model_class, trace, state)`` is
    returned instead.
    NOTE(review): on success the loaded trace is not attached to the returned
    model; confirm whether callers need it.
    """
    trace = trace_from_sql(filename)
    state = trace_from_sql(filename, table='state')
    cxn = sql.connect(filename)
    try:
        rows = cxn.execute('SELECT model_class FROM model').fetchall()
    finally:
        cxn.close()
    # NOTE(review): unpickling executes arbitrary code; only load trusted files.
    model_class = pkl.loads(rows[0][0])
    try:
        # Previously the constructed model was discarded and None was returned.
        return model_class(**state)
    except Exception:
        warnings.warn('initializing model {} from state failed! '
                      'Returning model_class, trace, state.'.format(model_class),
                      stacklevel=2)
        return model_class, trace, state
134
def maybe_deserialize(maybe_bytestring):
    """
    Try to turn a stored value back into a Python object.

    Lists and tuples are handled element-wise, keeping their container type.
    Any other value is tried, in order, with ``pickle.loads``, ``dill.loads``
    (only when dill is importable), and ``float``; the first that succeeds wins.
    If none succeeds, the value is returned unchanged. Note that plain ints
    therefore come back as floats.

    NOTE(review): pickle/dill execute arbitrary code while loading; only feed
    this data from trusted databases.
    """
    if isinstance(maybe_bytestring, (list, tuple)):
        return type(maybe_bytestring)([maybe_deserialize(element)
                                       for element in maybe_bytestring])
    try:
        return pkl.loads(maybe_bytestring)
    except Exception:
        pass

    # dill is only needed when plain pickle fails. Importing it lazily avoids
    # retrying a failed import (which is not cached) on every successful call.
    try:
        import dill
    except ImportError:
        msg = 'The `dill` module is required to use the sqlite backend fully.'
        warnings.warn(msg, stacklevel=2)
    else:
        try:
            return dill.loads(maybe_bytestring)
        except Exception:
            pass

    try:
        return float(maybe_bytestring)
    except Exception:
        # Includes TypeError/ValueError, and OverflowError for very large ints.
        return maybe_bytestring
159
def serialize(v):
    """
    Convert ``v`` into something sqlite can store.

    Objects that provide their own ``dumps()`` method (numpy arrays do, for
    example) are serialized with it. Otherwise, plain ints and floats are
    returned untouched, and anything else is pickled.
    """
    self_serializing = hasattr(v, 'dumps')
    if self_serializing:
        return v.dumps()
    native_number = isinstance(v, (float, int))
    return v if native_number else pkl.dumps(v)
171