"""A Socket subclass that adds some serialization methods."""

import zlib
import pickle

import numpy

import zmq


class SerializingSocket(zmq.Socket):
    """A zmq.Socket with some extra serialization methods.

    send_zipped_pickle is just like send_pyobj, but uses
    zlib to compress the stream before sending.

    send_array sends numpy arrays with metadata necessary
    for reconstructing the array on the other side (dtype, shape).
    """

    def send_zipped_pickle(self, obj, flags=0, protocol=-1):
        """Pickle an object, compress it with zlib and send it as one frame.

        Parameters
        ----------
        obj : the Python object to serialize.
        flags : zmq send flags, forwarded to ``send``.
        protocol : pickle protocol passed to ``pickle.dumps``.

        Returns whatever ``self.send`` returns.
        """
        pobj = pickle.dumps(obj, protocol)
        zobj = zlib.compress(pobj)
        # Report the size of the compressed payload that goes on the wire.
        print('zipped pickle is %i bytes' % len(zobj))
        return self.send(zobj, flags=flags)

    def recv_zipped_pickle(self, flags=0):
        """Receive a frame sent by send_zipped_pickle and rebuild the object.

        Parameters
        ----------
        flags : zmq recv flags, forwarded to ``recv``.

        Returns the unpickled Python object.
        """
        zobj = self.recv(flags)
        pobj = zlib.decompress(zobj)
        # SECURITY: unpickling can execute arbitrary code. Only use this
        # method on sockets whose peers are trusted.
        return pickle.loads(pobj)

    def send_array(self, A, flags=0, copy=True, track=False):
        """Send a numpy array as two frames: JSON metadata, then raw data.

        Parameters
        ----------
        A : numpy array to send.
        flags : zmq send flags; ``zmq.SNDMORE`` is added for the metadata
            frame so both frames form one multipart message.
        copy, track : forwarded to ``send`` for the data frame.

        Returns whatever ``self.send`` returns for the data frame.
        """
        md = dict(
            dtype=str(A.dtype),
            shape=A.shape,
        )
        # recv_array rebuilds the array with a plain (C-order) reshape of the
        # received bytes, so the bytes must be laid out in C order. Without
        # this, Fortran-ordered arrays arrive scrambled and strided views
        # cannot be sent at all. Already C-contiguous arrays are not copied.
        data = numpy.ascontiguousarray(A)
        self.send_json(md, flags | zmq.SNDMORE)
        return self.send(data, flags, copy=copy, track=track)

    def recv_array(self, flags=0, copy=True, track=False):
        """Receive a numpy array sent by send_array.

        Parameters
        ----------
        flags : zmq recv flags, used for both the metadata and data frames.
        copy, track : forwarded to ``recv`` for the data frame.

        Returns an array built over the received buffer, with the dtype and
        shape given in the metadata frame.
        """
        md = self.recv_json(flags=flags)
        msg = self.recv(flags=flags, copy=copy, track=track)
        A = numpy.frombuffer(msg, dtype=md['dtype'])
        return A.reshape(md['shape'])


class SerializingContext(zmq.Context):
    """A zmq.Context whose sockets are SerializingSocket instances."""

    _socket_class = SerializingSocket


def main():
    """Round-trip an array over an inproc REQ/REP pair with both methods."""
    # The context managers close both sockets and then terminate the
    # context, even if one of the transfers raises.
    with SerializingContext() as ctx:
        with ctx.socket(zmq.REQ) as req, ctx.socket(zmq.REP) as rep:
            rep.bind('inproc://a')
            req.connect('inproc://a')
            A = numpy.ones((1024, 1024))
            print("Array is %i bytes" % (A.nbytes))

            # send/recv with pickle+zip
            req.send_zipped_pickle(A)
            B = rep.recv_zipped_pickle()
            # now try non-copying version
            rep.send_array(A, copy=False)
            C = req.recv_array(copy=False)
            print("Checking zipped pickle...")
            print("Okay" if (A == B).all() else "Failed")
            print("Checking send_array...")
            # Compare against the original array, not the pickle result, so
            # each check reports on its own transport only.
            print("Okay" if (A == C).all() else "Failed")


if __name__ == '__main__':
    main()