1# Simple example presenting how persistent ID can be used to pickle
2# external objects by reference.
3
4import pickle
5import sqlite3
6from collections import namedtuple
7
8# Simple class representing a record in our database.
9MemoRecord = namedtuple("MemoRecord", "key, task")
10
11class DBPickler(pickle.Pickler):
12
13    def persistent_id(self, obj):
14        # Instead of pickling MemoRecord as a regular class instance, we emit a
15        # persistent ID.
16        if isinstance(obj, MemoRecord):
17            # Here, our persistent ID is simply a tuple, containing a tag and a
18            # key, which refers to a specific record in the database.
19            return ("MemoRecord", obj.key)
20        else:
21            # If obj does not have a persistent ID, return None. This means obj
22            # needs to be pickled as usual.
23            return None
24
25
26class DBUnpickler(pickle.Unpickler):
27
28    def __init__(self, file, connection):
29        super().__init__(file)
30        self.connection = connection
31
32    def persistent_load(self, pid):
33        # This method is invoked whenever a persistent ID is encountered.
34        # Here, pid is the tuple returned by DBPickler.
35        cursor = self.connection.cursor()
36        type_tag, key_id = pid
37        if type_tag == "MemoRecord":
38            # Fetch the referenced record from the database and return it.
39            cursor.execute("SELECT * FROM memos WHERE key=?", (str(key_id),))
40            key, task = cursor.fetchone()
41            return MemoRecord(key, task)
42        else:
43            # Always raises an error if you cannot return the correct object.
44            # Otherwise, the unpickler will think None is the object referenced
45            # by the persistent ID.
46            raise pickle.UnpicklingError("unsupported persistent object")
47
48
49def main():
50    import io
51    import pprint
52
53    # Initialize and populate our database.
54    conn = sqlite3.connect(":memory:")
55    cursor = conn.cursor()
56    cursor.execute("CREATE TABLE memos(key INTEGER PRIMARY KEY, task TEXT)")
57    tasks = (
58        'give food to fish',
59        'prepare group meeting',
60        'fight with a zebra',
61        )
62    for task in tasks:
63        cursor.execute("INSERT INTO memos VALUES(NULL, ?)", (task,))
64
65    # Fetch the records to be pickled.
66    cursor.execute("SELECT * FROM memos")
67    memos = [MemoRecord(key, task) for key, task in cursor]
68    # Save the records using our custom DBPickler.
69    file = io.BytesIO()
70    DBPickler(file).dump(memos)
71
72    print("Pickled records:")
73    pprint.pprint(memos)
74
75    # Update a record, just for good measure.
76    cursor.execute("UPDATE memos SET task='learn italian' WHERE key=1")
77
78    # Load the records from the pickle data stream.
79    file.seek(0)
80    memos = DBUnpickler(file, conn).load()
81
82    print("Unpickled records:")
83    pprint.pprint(memos)
84
85
86if __name__ == '__main__':
87    main()
88