"""
SQLite persistence helpers for model traces, states and model classes.

Each traced parameter becomes a BLOB column in a table keyed by an integer
``iteration`` primary key. Values are stored as plain numbers when they are
ints/floats and as pickled bytes otherwise.
"""
import sqlite3 as sql
import numpy as np
import os
import sys
import warnings
import pickle as pkl

LEGACY_PYTHON = sys.version_info[0] < 3

CREATE_TEMPLATE = "CREATE TABLE {} (iteration INTEGER PRIMARY KEY, {})"
INSERT_TEMPLATE = "INSERT INTO {} VALUES (?, {})"

byte_type = str if LEGACY_PYTHON else bytes


def customize_create_template(colnames, tablename):
    """
    Build the CREATE TABLE statement for a table name and a set of column names.

    Every column in `colnames` is declared as a double-quoted BLOB column that
    follows the leading `iteration` primary key.
    """
    columns = ' , '.join(['"' + param + '"' + ' BLOB' for param in colnames])
    return CREATE_TEMPLATE.format(tablename, columns)


def customize_insert_template(colnames, tablename):
    """
    Build the INSERT INTO statement for a table name and a set of column names.

    Only the *number* of columns matters: one `?` placeholder is emitted per
    entry of `colnames`, after the placeholder for the iteration key.
    """
    data_part = ' , '.join(['?' for _ in colnames])
    return INSERT_TEMPLATE.format(tablename, data_part)


def start_sql(model, tracename='model_trace.db'):
    """
    Start a SQLite connection to a local database, specified by tracename.

    Creates the `trace` table with one column per entry of
    `model.traced_params` and returns the `(connection, cursor)` pair.

    Raises an Exception if a file called `tracename` already exists, so that an
    existing trace is never overwritten.
    """
    if os.path.isfile(tracename):
        raise Exception('Will not overwrite existing trace {}'.format(tracename))
    cxn = sql.connect(tracename)
    cursor = cxn.cursor()
    cursor.execute(customize_create_template(model.traced_params, 'trace'))
    return cxn, cursor


def head_to_sql(model, cursor, connection):
    """
    Send the most recent trace point to the sql database.
    """
    point_to_sql(model, cursor, connection, index=-1)


def point_to_sql(model, cursor, connection, index=0):
    """
    Send an arbitrary index point to the database and commit it.

    A non-negative `index` is used directly as the iteration key. A negative
    `index` counts back from the end of the trace: -1 is stored under
    iteration `model.cycles`, -2 under `model.cycles - 1`, and so on.
    """
    if index < 0:
        # Count backwards from the head. The previous expression,
        # `cycles - (index + 1)`, agreed with this for index == -1 but moved
        # in the wrong direction for every other negative index.
        iteration = model.cycles + index + 1
    else:
        iteration = index
    to_insert = [iteration]
    to_insert.extend(serialize(model.trace[param, index])
                     for param in model.traced_params)
    cursor.execute(customize_insert_template(model.traced_params, 'trace'),
                   tuple(to_insert))
    connection.commit()


def trace_to_sql(model, cursor, connection):
    """
    Send a model's entire trace to the database.

    Rows are keyed by their zero-based position in the trace, for positions
    `0 .. model.cycles - 1`, and committed together once all rows are written.
    """
    # The statement only depends on the column count, so build it once.
    statement = customize_insert_template(model.traced_params, 'trace')
    for i in range(model.cycles):
        to_insert = [i]
        to_insert.extend(serialize(model.trace[param, i])
                         for param in model.traced_params)
        cursor.execute(statement, tuple(to_insert))
    connection.commit()


def trace_from_sql(filename, table='trace'):
    """
    Reconstruct a model trace from the database.

    For the `trace` table every column becomes a list of deserialized entries;
    for any other table only the first row of each column is deserialized.
    The table name is interpolated into the SQL text, so it must come from a
    trusted source.
    """
    # connect, parse header, then pull every row; always release the connection
    cxn = sql.connect(filename)
    try:
        pragma = cxn.execute('PRAGMA table_info({})'.format(table)).fetchall()
        colnames = [t[1] for t in pragma]
        data = cxn.execute('SELECT * FROM {}'.format(table)).fetchall()
    finally:
        cxn.close()
    # transpose rows into per-column lists, paired with their column names
    records = zip(colnames, map(list, zip(*data)))

    # Import must occur here otherwise there's a circularity issue
    from .abstracts import Trace

    if table == 'trace':
        out = Trace(**{colname: [maybe_deserialize(entry) for entry in column]
                       for colname, column in records})
    else:
        out = Trace(**{colname: maybe_deserialize(column[0])
                       for colname, column in records})
    return out


def model_to_sql(model, cursor, connection):
    """
    Serialize an entire model into a sqlite database. This serializes the trace
    into the `trace` table, the state into the `state` table, and the model class
    into the `model` table. All items are pickled using their own dumps method, if
    possible. Otherwise, numbers are stored as-is and other objects are reduced
    using pickle.dumps, which is then passed to sqlite as a BLOB.
    """
    trace_to_sql(model, cursor, connection)

    frozen_state_keys = list(model.state.varnames)
    cursor.execute(customize_create_template(frozen_state_keys, 'state'))
    to_insert = [model.cycles]
    to_insert.extend(serialize(model.state[k]) for k in frozen_state_keys)
    cursor.execute(customize_insert_template(frozen_state_keys, 'state'),
                   tuple(to_insert))

    # Only the class is stored, not the model instance.
    model_columns = ['model_class']
    class_pkl = pkl.dumps(model.__class__)
    cursor.execute(customize_create_template(model_columns, 'model'))
    # A None key lets sqlite assign the INTEGER PRIMARY KEY itself.
    cursor.execute(customize_insert_template(model_columns, 'model'),
                   (None, class_pkl))
    connection.commit()


def model_from_sql(filename):
    """
    Reconstruct a model from a sqlite table with a given trace, state, and model tables.

    If the serialization fails for the trace or state, the resulting
    trace/state may contain raw binary strings. If the serialization fails for the model/there
    is no model table, the function will fail. To just extract the trace or the state,
    use trace_from_sql.

    Returns the tuple `(model_class, trace, state)`. A warning is issued when
    the class cannot be instantiated from the stored state.
    """
    trace = trace_from_sql(filename)
    state = trace_from_sql(filename, table='state')
    cxn = sql.connect(filename)
    try:
        rows = cxn.execute('SELECT model_class FROM model').fetchall()
    finally:
        cxn.close()
    # NOTE(security): unpickling executes arbitrary code; only open trusted databases.
    model_class = pkl.loads(rows[0][0])
    try:
        model_class(**state)
    except Exception:
        warnings.warn('initializing model {} from state failed! '
                      'Returning trace, state, model.'.format(model_class),
                      stacklevel=2)
    return model_class, trace, state


def maybe_deserialize(maybe_bytestring):
    """
    This attempts to deserialize an object, but may return the original object
    if no deserialization is successful.

    Lists and tuples are deserialized element by element. Otherwise the value
    is tried, in order, with pickle, dill (when installed) and `float`; the
    first loader that succeeds wins.
    """
    try:
        import dill
    except ImportError:
        dill = None
        msg = 'The `dill` module is required to use the sqlite backend fully.'
        warnings.warn(msg, stacklevel=2)

    if isinstance(maybe_bytestring, (list, tuple)):
        return type(maybe_bytestring)([maybe_deserialize(byte_element)
                                       for byte_element in maybe_bytestring])

    # NOTE(security): pickle/dill execute arbitrary code on load; the database
    # contents must be trusted.
    loaders = [pkl.loads]
    if dill is not None:
        loaders.append(dill.loads)
    loaders.append(float)
    for load in loaders:
        try:
            return load(maybe_bytestring)
        except Exception:
            # best effort: fall through to the next loader
            continue
    return maybe_bytestring


def serialize(v):
    """
    Serialize an object for storage in a BLOB column.

    Objects exposing a `dumps` method are serialized with it, ints and floats
    are returned unchanged, and anything else is pickled.
    """
    if hasattr(v, 'dumps'):
        return v.dumps()
    elif isinstance(v, (float, int)):
        return v
    else:
        return pkl.dumps(v)