1# 2# Licensed to the Apache Software Foundation (ASF) under one or more 3# contributor license agreements. See the NOTICE file distributed with 4# this work for additional information regarding copyright ownership. 5# The ASF licenses this file to You under the Apache License, Version 2.0 6# (the "License"); you may not use this file except in compliance with 7# the License. You may obtain a copy of the License at 8# 9# http://www.apache.org/licenses/LICENSE-2.0 10# 11# Unless required by applicable law or agreed to in writing, software 12# distributed under the License is distributed on an "AS IS" BASIS, 13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14# See the License for the specific language governing permissions and 15# limitations under the License. 16# 17 18""" 19PySpark supports custom serializers for transferring data; this can improve 20performance. 21 22By default, PySpark uses L{PickleSerializer} to serialize objects using Python's 23C{cPickle} serializer, which can serialize nearly any Python object. 24Other serializers, like L{MarshalSerializer}, support fewer datatypes but can be 25faster. 26 27The serializer is chosen when creating L{SparkContext}: 28 29>>> from pyspark.context import SparkContext 30>>> from pyspark.serializers import MarshalSerializer 31>>> sc = SparkContext('local', 'test', serializer=MarshalSerializer()) 32>>> sc.parallelize(list(range(1000))).map(lambda x: 2 * x).take(10) 33[0, 2, 4, 6, 8, 10, 12, 14, 16, 18] 34>>> sc.stop() 35 36PySpark serialize objects in batches; By default, the batch size is chosen based 37on the size of objects, also configurable by SparkContext's C{batchSize} parameter: 38 39>>> sc = SparkContext('local', 'test', batchSize=2) 40>>> rdd = sc.parallelize(range(16), 4).map(lambda x: x) 41 42Behind the scenes, this creates a JavaRDD with four partitions, each of 43which contains two batches of two objects: 44 45>>> rdd.glom().collect() 46[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]] 47>>> int(rdd._jrdd.count()) 488 49>>> sc.stop() 50""" 51 52import sys 53from itertools import chain, product 54import marshal 55import struct 56import types 57import collections 58import zlib 59import itertools 60 61if sys.version < '3': 62 import cPickle as pickle 63 protocol = 2 64 from itertools import izip as zip, imap as map 65else: 66 import pickle 67 protocol = 3 68 xrange = range 69 70from pyspark import cloudpickle 71 72 73__all__ = ["PickleSerializer", "MarshalSerializer", "UTF8Deserializer"] 74 75 76class SpecialLengths(object): 77 END_OF_DATA_SECTION = -1 78 PYTHON_EXCEPTION_THROWN = -2 79 TIMING_DATA = -3 80 END_OF_STREAM = -4 81 NULL = -5 82 83 84class Serializer(object): 85 86 def dump_stream(self, iterator, stream): 87 """ 88 Serialize an iterator of objects to the output stream. 89 """ 90 raise NotImplementedError 91 92 def load_stream(self, stream): 93 """ 94 Return an iterator of deserialized objects from the input stream. 95 """ 96 raise NotImplementedError 97 98 def _load_stream_without_unbatching(self, stream): 99 """ 100 Return an iterator of deserialized batches (lists) of objects from the input stream. 101 if the serializer does not operate on batches the default implementation returns an 102 iterator of single element lists. 103 """ 104 return map(lambda x: [x], self.load_stream(stream)) 105 106 # Note: our notion of "equality" is that output generated by 107 # equal serializers can be deserialized using the same serializer. 108 109 # This default implementation handles the simple cases; 110 # subclasses should override __eq__ as appropriate. 111 112 def __eq__(self, other): 113 return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ 114 115 def __ne__(self, other): 116 return not self.__eq__(other) 117 118 def __repr__(self): 119 return "%s()" % self.__class__.__name__ 120 121 def __hash__(self): 122 return hash(str(self)) 123 124 125class FramedSerializer(Serializer): 126 127 """ 128 Serializer that writes objects as a stream of (length, data) pairs, 129 where C{length} is a 32-bit integer and data is C{length} bytes. 130 """ 131 132 def __init__(self): 133 # On Python 2.6, we can't write bytearrays to streams, so we need to convert them 134 # to strings first. Check if the version number is that old. 135 self._only_write_strings = sys.version_info[0:2] <= (2, 6) 136 137 def dump_stream(self, iterator, stream): 138 for obj in iterator: 139 self._write_with_length(obj, stream) 140 141 def load_stream(self, stream): 142 while True: 143 try: 144 yield self._read_with_length(stream) 145 except EOFError: 146 return 147 148 def _write_with_length(self, obj, stream): 149 serialized = self.dumps(obj) 150 if serialized is None: 151 raise ValueError("serialized value should not be None") 152 if len(serialized) > (1 << 31): 153 raise ValueError("can not serialize object larger than 2G") 154 write_int(len(serialized), stream) 155 if self._only_write_strings: 156 stream.write(str(serialized)) 157 else: 158 stream.write(serialized) 159 160 def _read_with_length(self, stream): 161 length = read_int(stream) 162 if length == SpecialLengths.END_OF_DATA_SECTION: 163 raise EOFError 164 elif length == SpecialLengths.NULL: 165 return None 166 obj = stream.read(length) 167 if len(obj) < length: 168 raise EOFError 169 return self.loads(obj) 170 171 def dumps(self, obj): 172 """ 173 Serialize an object into a byte array. 174 When batching is used, this will be called with an array of objects. 175 """ 176 raise NotImplementedError 177 178 def loads(self, obj): 179 """ 180 Deserialize an object from a byte array. 181 """ 182 raise NotImplementedError 183 184 185class BatchedSerializer(Serializer): 186 187 """ 188 Serializes a stream of objects in batches by calling its wrapped 189 Serializer with streams of objects. 190 """ 191 192 UNLIMITED_BATCH_SIZE = -1 193 UNKNOWN_BATCH_SIZE = 0 194 195 def __init__(self, serializer, batchSize=UNLIMITED_BATCH_SIZE): 196 self.serializer = serializer 197 self.batchSize = batchSize 198 199 def _batched(self, iterator): 200 if self.batchSize == self.UNLIMITED_BATCH_SIZE: 201 yield list(iterator) 202 elif hasattr(iterator, "__len__") and hasattr(iterator, "__getslice__"): 203 n = len(iterator) 204 for i in xrange(0, n, self.batchSize): 205 yield iterator[i: i + self.batchSize] 206 else: 207 items = [] 208 count = 0 209 for item in iterator: 210 items.append(item) 211 count += 1 212 if count == self.batchSize: 213 yield items 214 items = [] 215 count = 0 216 if items: 217 yield items 218 219 def dump_stream(self, iterator, stream): 220 self.serializer.dump_stream(self._batched(iterator), stream) 221 222 def load_stream(self, stream): 223 return chain.from_iterable(self._load_stream_without_unbatching(stream)) 224 225 def _load_stream_without_unbatching(self, stream): 226 return self.serializer.load_stream(stream) 227 228 def __repr__(self): 229 return "BatchedSerializer(%s, %d)" % (str(self.serializer), self.batchSize) 230 231 232class FlattenedValuesSerializer(BatchedSerializer): 233 234 """ 235 Serializes a stream of list of pairs, split the list of values 236 which contain more than a certain number of objects to make them 237 have similar sizes. 238 """ 239 def __init__(self, serializer, batchSize=10): 240 BatchedSerializer.__init__(self, serializer, batchSize) 241 242 def _batched(self, iterator): 243 n = self.batchSize 244 for key, values in iterator: 245 for i in range(0, len(values), n): 246 yield key, values[i:i + n] 247 248 def load_stream(self, stream): 249 return self.serializer.load_stream(stream) 250 251 def __repr__(self): 252 return "FlattenedValuesSerializer(%s, %d)" % (self.serializer, self.batchSize) 253 254 255class AutoBatchedSerializer(BatchedSerializer): 256 """ 257 Choose the size of batch automatically based on the size of object 258 """ 259 260 def __init__(self, serializer, bestSize=1 << 16): 261 BatchedSerializer.__init__(self, serializer, self.UNKNOWN_BATCH_SIZE) 262 self.bestSize = bestSize 263 264 def dump_stream(self, iterator, stream): 265 batch, best = 1, self.bestSize 266 iterator = iter(iterator) 267 while True: 268 vs = list(itertools.islice(iterator, batch)) 269 if not vs: 270 break 271 272 bytes = self.serializer.dumps(vs) 273 write_int(len(bytes), stream) 274 stream.write(bytes) 275 276 size = len(bytes) 277 if size < best: 278 batch *= 2 279 elif size > best * 10 and batch > 1: 280 batch //= 2 281 282 def __repr__(self): 283 return "AutoBatchedSerializer(%s)" % self.serializer 284 285 286class CartesianDeserializer(Serializer): 287 288 """ 289 Deserializes the JavaRDD cartesian() of two PythonRDDs. 290 Due to pyspark batching we cannot simply use the result of the Java RDD cartesian, 291 we additionally need to do the cartesian within each pair of batches. 292 """ 293 294 def __init__(self, key_ser, val_ser): 295 self.key_ser = key_ser 296 self.val_ser = val_ser 297 298 def _load_stream_without_unbatching(self, stream): 299 key_batch_stream = self.key_ser._load_stream_without_unbatching(stream) 300 val_batch_stream = self.val_ser._load_stream_without_unbatching(stream) 301 for (key_batch, val_batch) in zip(key_batch_stream, val_batch_stream): 302 # for correctness with repeated cartesian/zip this must be returned as one batch 303 yield product(key_batch, val_batch) 304 305 def load_stream(self, stream): 306 return chain.from_iterable(self._load_stream_without_unbatching(stream)) 307 308 def __repr__(self): 309 return "CartesianDeserializer(%s, %s)" % \ 310 (str(self.key_ser), str(self.val_ser)) 311 312 313class PairDeserializer(Serializer): 314 315 """ 316 Deserializes the JavaRDD zip() of two PythonRDDs. 317 Due to pyspark batching we cannot simply use the result of the Java RDD zip, 318 we additionally need to do the zip within each pair of batches. 319 """ 320 321 def __init__(self, key_ser, val_ser): 322 self.key_ser = key_ser 323 self.val_ser = val_ser 324 325 def _load_stream_without_unbatching(self, stream): 326 key_batch_stream = self.key_ser._load_stream_without_unbatching(stream) 327 val_batch_stream = self.val_ser._load_stream_without_unbatching(stream) 328 for (key_batch, val_batch) in zip(key_batch_stream, val_batch_stream): 329 if len(key_batch) != len(val_batch): 330 raise ValueError("Can not deserialize PairRDD with different number of items" 331 " in batches: (%d, %d)" % (len(key_batch), len(val_batch))) 332 # for correctness with repeated cartesian/zip this must be returned as one batch 333 yield zip(key_batch, val_batch) 334 335 def load_stream(self, stream): 336 return chain.from_iterable(self._load_stream_without_unbatching(stream)) 337 338 def __repr__(self): 339 return "PairDeserializer(%s, %s)" % (str(self.key_ser), str(self.val_ser)) 340 341 342class NoOpSerializer(FramedSerializer): 343 344 def loads(self, obj): 345 return obj 346 347 def dumps(self, obj): 348 return obj 349 350 351# Hook namedtuple, make it picklable 352 353__cls = {} 354 355 356def _restore(name, fields, value): 357 """ Restore an object of namedtuple""" 358 k = (name, fields) 359 cls = __cls.get(k) 360 if cls is None: 361 cls = collections.namedtuple(name, fields) 362 __cls[k] = cls 363 return cls(*value) 364 365 366def _hack_namedtuple(cls): 367 """ Make class generated by namedtuple picklable """ 368 name = cls.__name__ 369 fields = cls._fields 370 371 def __reduce__(self): 372 return (_restore, (name, fields, tuple(self))) 373 cls.__reduce__ = __reduce__ 374 cls._is_namedtuple_ = True 375 return cls 376 377 378def _hijack_namedtuple(): 379 """ Hack namedtuple() to make it picklable """ 380 # hijack only one time 381 if hasattr(collections.namedtuple, "__hijack"): 382 return 383 384 global _old_namedtuple # or it will put in closure 385 global _old_namedtuple_kwdefaults # or it will put in closure too 386 387 def _copy_func(f): 388 return types.FunctionType(f.__code__, f.__globals__, f.__name__, 389 f.__defaults__, f.__closure__) 390 391 def _kwdefaults(f): 392 # __kwdefaults__ contains the default values of keyword-only arguments which are 393 # introduced from Python 3. The possible cases for __kwdefaults__ in namedtuple 394 # are as below: 395 # 396 # - Does not exist in Python 2. 397 # - Returns None in <= Python 3.5.x. 398 # - Returns a dictionary containing the default values to the keys from Python 3.6.x 399 # (See https://bugs.python.org/issue25628). 400 kargs = getattr(f, "__kwdefaults__", None) 401 if kargs is None: 402 return {} 403 else: 404 return kargs 405 406 _old_namedtuple = _copy_func(collections.namedtuple) 407 _old_namedtuple_kwdefaults = _kwdefaults(collections.namedtuple) 408 409 def namedtuple(*args, **kwargs): 410 for k, v in _old_namedtuple_kwdefaults.items(): 411 kwargs[k] = kwargs.get(k, v) 412 cls = _old_namedtuple(*args, **kwargs) 413 return _hack_namedtuple(cls) 414 415 # replace namedtuple with new one 416 collections.namedtuple.__globals__["_old_namedtuple_kwdefaults"] = _old_namedtuple_kwdefaults 417 collections.namedtuple.__globals__["_old_namedtuple"] = _old_namedtuple 418 collections.namedtuple.__globals__["_hack_namedtuple"] = _hack_namedtuple 419 collections.namedtuple.__code__ = namedtuple.__code__ 420 collections.namedtuple.__hijack = 1 421 422 # hack the cls already generated by namedtuple 423 # those created in other module can be pickled as normal, 424 # so only hack those in __main__ module 425 for n, o in sys.modules["__main__"].__dict__.items(): 426 if (type(o) is type and o.__base__ is tuple 427 and hasattr(o, "_fields") 428 and "__reduce__" not in o.__dict__): 429 _hack_namedtuple(o) # hack inplace 430 431 432_hijack_namedtuple() 433 434 435class PickleSerializer(FramedSerializer): 436 437 """ 438 Serializes objects using Python's pickle serializer: 439 440 http://docs.python.org/2/library/pickle.html 441 442 This serializer supports nearly any Python object, but may 443 not be as fast as more specialized serializers. 444 """ 445 446 def dumps(self, obj): 447 return pickle.dumps(obj, protocol) 448 449 if sys.version >= '3': 450 def loads(self, obj, encoding="bytes"): 451 return pickle.loads(obj, encoding=encoding) 452 else: 453 def loads(self, obj, encoding=None): 454 return pickle.loads(obj) 455 456 457class CloudPickleSerializer(PickleSerializer): 458 459 def dumps(self, obj): 460 return cloudpickle.dumps(obj, 2) 461 462 463class MarshalSerializer(FramedSerializer): 464 465 """ 466 Serializes objects using Python's Marshal serializer: 467 468 http://docs.python.org/2/library/marshal.html 469 470 This serializer is faster than PickleSerializer but supports fewer datatypes. 471 """ 472 473 def dumps(self, obj): 474 return marshal.dumps(obj) 475 476 def loads(self, obj): 477 return marshal.loads(obj) 478 479 480class AutoSerializer(FramedSerializer): 481 482 """ 483 Choose marshal or pickle as serialization protocol automatically 484 """ 485 486 def __init__(self): 487 FramedSerializer.__init__(self) 488 self._type = None 489 490 def dumps(self, obj): 491 if self._type is not None: 492 return b'P' + pickle.dumps(obj, -1) 493 try: 494 return b'M' + marshal.dumps(obj) 495 except Exception: 496 self._type = b'P' 497 return b'P' + pickle.dumps(obj, -1) 498 499 def loads(self, obj): 500 _type = obj[0] 501 if _type == b'M': 502 return marshal.loads(obj[1:]) 503 elif _type == b'P': 504 return pickle.loads(obj[1:]) 505 else: 506 raise ValueError("invalid sevialization type: %s" % _type) 507 508 509class CompressedSerializer(FramedSerializer): 510 """ 511 Compress the serialized data 512 """ 513 def __init__(self, serializer): 514 FramedSerializer.__init__(self) 515 assert isinstance(serializer, FramedSerializer), "serializer must be a FramedSerializer" 516 self.serializer = serializer 517 518 def dumps(self, obj): 519 return zlib.compress(self.serializer.dumps(obj), 1) 520 521 def loads(self, obj): 522 return self.serializer.loads(zlib.decompress(obj)) 523 524 def __repr__(self): 525 return "CompressedSerializer(%s)" % self.serializer 526 527 528class UTF8Deserializer(Serializer): 529 530 """ 531 Deserializes streams written by String.getBytes. 532 """ 533 534 def __init__(self, use_unicode=True): 535 self.use_unicode = use_unicode 536 537 def loads(self, stream): 538 length = read_int(stream) 539 if length == SpecialLengths.END_OF_DATA_SECTION: 540 raise EOFError 541 elif length == SpecialLengths.NULL: 542 return None 543 s = stream.read(length) 544 return s.decode("utf-8") if self.use_unicode else s 545 546 def load_stream(self, stream): 547 try: 548 while True: 549 yield self.loads(stream) 550 except struct.error: 551 return 552 except EOFError: 553 return 554 555 def __repr__(self): 556 return "UTF8Deserializer(%s)" % self.use_unicode 557 558 559def read_long(stream): 560 length = stream.read(8) 561 if not length: 562 raise EOFError 563 return struct.unpack("!q", length)[0] 564 565 566def write_long(value, stream): 567 stream.write(struct.pack("!q", value)) 568 569 570def pack_long(value): 571 return struct.pack("!q", value) 572 573 574def read_int(stream): 575 length = stream.read(4) 576 if not length: 577 raise EOFError 578 return struct.unpack("!i", length)[0] 579 580 581def write_int(value, stream): 582 stream.write(struct.pack("!i", value)) 583 584 585def write_with_length(obj, stream): 586 write_int(len(obj), stream) 587 stream.write(obj) 588 589 590if __name__ == '__main__': 591 import doctest 592 (failure_count, test_count) = doctest.testmod() 593 if failure_count: 594 exit(-1) 595