1"""A Socket subclass that adds some serialization methods."""
2
3import zlib
4import pickle
5
6import numpy
7
8import zmq
9
10
class SerializingSocket(zmq.Socket):
    """A zmq.Socket with some extra serialization methods.

    send_zipped_pickle is just like send_pyobj, but uses
    zlib to compress the stream before sending.

    send_array sends numpy arrays with the metadata necessary
    for reconstructing the array on the other side (dtype, shape).
    """

    def send_zipped_pickle(self, obj, flags=0, protocol=-1):
        """Pickle *obj*, compress it with zlib and send it as one frame.

        Parameters
        ----------
        obj : object
            Any picklable object.
        flags : int
            zmq send flags, passed through to ``send``.
        protocol : int
            Pickle protocol passed to ``pickle.dumps`` (-1 = highest).

        Returns
        -------
        Whatever ``self.send`` returns.
        """
        pobj = pickle.dumps(obj, protocol)
        zobj = zlib.compress(pobj)
        # Demo output: report the compressed size.
        print('zipped pickle is %i bytes' % len(zobj))
        return self.send(zobj, flags=flags)

    def recv_zipped_pickle(self, flags=0):
        """Receive one frame and reconstruct an object sent with send_zipped_pickle.

        WARNING: ``pickle.loads`` can execute arbitrary code and
        ``zlib.decompress`` output size is unbounded; only use this with
        trusted peers.
        """
        zobj = self.recv(flags)
        pobj = zlib.decompress(zobj)
        return pickle.loads(pobj)

    def send_array(self, A, flags=0, copy=True, track=False):
        """Send a numpy array as two frames: JSON metadata, then raw data.

        Parameters
        ----------
        A : numpy.ndarray
            Array to send. Non-C-contiguous arrays (views, transposes) are
            copied to C order first so the receiver's C-order reshape
            reproduces the same values.
        flags : int
            zmq send flags; SNDMORE is added for the metadata frame.
        copy, track :
            Passed through to ``send`` for the data frame.

        Returns
        -------
        Whatever ``self.send`` returns for the data frame.
        """
        # Metadata comes from the caller's array, before any conversion.
        md = dict(
            dtype=str(A.dtype),
            shape=A.shape,
        )
        if not A.flags['C_CONTIGUOUS']:
            # Otherwise a Fortran-ordered buffer would be sent as-is and
            # recv_array's reshape would scramble the elements.
            A = numpy.ascontiguousarray(A)
        self.send_json(md, flags | zmq.SNDMORE)
        return self.send(A, flags, copy=copy, track=track)

    def recv_array(self, flags=0, copy=True, track=False):
        """Receive an array sent with send_array.

        Reads the JSON metadata frame, then the data frame, and views the
        data as a C-ordered array of the recorded dtype and shape.
        """
        md = self.recv_json(flags=flags)
        msg = self.recv(flags=flags, copy=copy, track=track)
        A = numpy.frombuffer(msg, dtype=md['dtype'])
        return A.reshape(md['shape'])
49
50
class SerializingContext(zmq.Context):
    """A zmq.Context whose sockets are SerializingSocket instances."""

    # pyzmq's Context.socket() builds new sockets from this class, so
    # ctx.socket(...) returns objects with the extra send/recv helpers.
    _socket_class = SerializingSocket
53
54
def main():
    """Demo: round-trip a 1024x1024 array over an inproc REQ/REP pair.

    The array goes REQ -> REP as a zlib-compressed pickle, then comes back
    REP -> REQ via send_array/recv_array without copying. Both results are
    checked against the original array.
    """
    ctx = SerializingContext()
    req = ctx.socket(zmq.REQ)
    rep = ctx.socket(zmq.REP)
    try:
        rep.bind('inproc://a')
        req.connect('inproc://a')
        A = numpy.ones((1024, 1024))
        print("Array is %i bytes" % (A.nbytes))

        # send/recv with pickle+zip
        req.send_zipped_pickle(A)
        B = rep.recv_zipped_pickle()
        # now try non-copying version
        rep.send_array(A, copy=False)
        C = req.recv_array(copy=False)
        print("Checking zipped pickle...")
        print("Okay" if (A == B).all() else "Failed")
        print("Checking send_array...")
        # Compare with the original A (not B) so this check is independent
        # of whether the pickle round-trip succeeded.
        print("Okay" if (C == A).all() else "Failed")
    finally:
        # Close the sockets before terminating: term() waits for all
        # sockets of the context to be closed.
        req.close()
        rep.close()
        ctx.term()
75
76
if __name__ == '__main__':
    # Run the demo only when executed as a script, not when imported.
    main()
79