#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
PySpark supports custom serializers for transferring data; this can improve
performance.

By default, PySpark uses L{PickleSerializer} to serialize objects using Python's
C{cPickle} serializer, which can serialize nearly any Python object.
Other serializers, like L{MarshalSerializer}, support fewer datatypes but can be
faster.

The serializer is chosen when creating L{SparkContext}:

>>> from pyspark.context import SparkContext
>>> from pyspark.serializers import MarshalSerializer
>>> sc = SparkContext('local', 'test', serializer=MarshalSerializer())
>>> sc.parallelize(list(range(1000))).map(lambda x: 2 * x).take(10)
[0, 2, 4, 6, 8, 10, 12, 14, 16, 18]
>>> sc.stop()

PySpark serializes objects in batches; by default, the batch size is chosen based
on the size of the objects. It is also configurable via SparkContext's C{batchSize} parameter:

>>> sc = SparkContext('local', 'test', batchSize=2)
>>> rdd = sc.parallelize(range(16), 4).map(lambda x: x)

Behind the scenes, this creates a JavaRDD with four partitions, each of
which contains two batches of two objects:

>>> rdd.glom().collect()
[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]
>>> int(rdd._jrdd.count())
8
>>> sc.stop()
"""

import sys
from itertools import chain, product
import marshal
import struct
import types
import collections
import zlib
import itertools

if sys.version < '3':
    import cPickle as pickle
    protocol = 2
    from itertools import izip as zip, imap as map
else:
    import pickle
    protocol = 3
    xrange = range

from pyspark import cloudpickle


__all__ = ["PickleSerializer", "MarshalSerializer", "UTF8Deserializer"]


class SpecialLengths(object):
    END_OF_DATA_SECTION = -1
    PYTHON_EXCEPTION_THROWN = -2
    TIMING_DATA = -3
    END_OF_STREAM = -4
    NULL = -5


class Serializer(object):

    def dump_stream(self, iterator, stream):
        """
        Serialize an iterator of objects to the output stream.
        """
        raise NotImplementedError

    def load_stream(self, stream):
        """
        Return an iterator of deserialized objects from the input stream.
        """
        raise NotImplementedError

    def _load_stream_without_unbatching(self, stream):
        """
        Return an iterator of deserialized batches (lists) of objects from the input stream.
        If the serializer does not operate on batches, the default implementation returns an
        iterator of single-element lists.
        """
        return map(lambda x: [x], self.load_stream(stream))

    # Note: our notion of "equality" is that output generated by
    # equal serializers can be deserialized using the same serializer.

    # This default implementation handles the simple cases;
    # subclasses should override __eq__ as appropriate.

    def __eq__(self, other):
        return isinstance(other, self.__class__) and self.__dict__ == other.__dict__

    def __ne__(self, other):
        return not self.__eq__(other)

    def __repr__(self):
        return "%s()" % self.__class__.__name__

    def __hash__(self):
        return hash(str(self))


class FramedSerializer(Serializer):

    """
    Serializer that writes objects as a stream of (length, data) pairs,
    where C{length} is a 32-bit integer and data is C{length} bytes.
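
    A minimal round-trip sketch of the framing, using the concrete
    L{PickleSerializer} defined below (illustrative, not exhaustive):

    >>> import io, struct
    >>> ser = PickleSerializer()
    >>> buf = io.BytesIO()
    >>> ser.dump_stream(iter([b"spark"]), buf)
    >>> data = buf.getvalue()
    >>> struct.unpack("!i", data[:4])[0] == len(data) - 4
    True
    >>> _ = buf.seek(0)
    >>> list(ser.load_stream(buf)) == [b"spark"]
    True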
    """

    def __init__(self):
        # On Python 2.6, we can't write bytearrays to streams, so we need to convert them
        # to strings first. Check if the version number is that old.
        self._only_write_strings = sys.version_info[0:2] <= (2, 6)

    def dump_stream(self, iterator, stream):
        for obj in iterator:
            self._write_with_length(obj, stream)

    def load_stream(self, stream):
        while True:
            try:
                yield self._read_with_length(stream)
            except EOFError:
                return

    def _write_with_length(self, obj, stream):
        serialized = self.dumps(obj)
        if serialized is None:
            raise ValueError("serialized value should not be None")
        if len(serialized) > (1 << 31):
            raise ValueError("can not serialize object larger than 2G")
        write_int(len(serialized), stream)
        if self._only_write_strings:
            stream.write(str(serialized))
        else:
            stream.write(serialized)

    def _read_with_length(self, stream):
        length = read_int(stream)
        if length == SpecialLengths.END_OF_DATA_SECTION:
            raise EOFError
        elif length == SpecialLengths.NULL:
            return None
        obj = stream.read(length)
        if len(obj) < length:
            raise EOFError
        return self.loads(obj)

    def dumps(self, obj):
        """
        Serialize an object into a byte array.
        When batching is used, this will be called with an array of objects.
        """
        raise NotImplementedError

    def loads(self, obj):
        """
        Deserialize an object from a byte array.
        """
        raise NotImplementedError


class BatchedSerializer(Serializer):

    """
    Serializes a stream of objects in batches by calling its wrapped
    Serializer with streams of objects.
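
    For example, with a batch size of 2 (an illustrative sketch):

    >>> import io
    >>> ser = BatchedSerializer(PickleSerializer(), 2)
    >>> list(ser._batched(iter(range(5))))
    [[0, 1], [2, 3], [4]]
    >>> buf = io.BytesIO()
    >>> ser.dump_stream(iter(range(5)), buf)
    >>> _ = buf.seek(0)
    >>> list(ser.load_stream(buf))
    [0, 1, 2, 3, 4]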
    """

    UNLIMITED_BATCH_SIZE = -1
    UNKNOWN_BATCH_SIZE = 0

    def __init__(self, serializer, batchSize=UNLIMITED_BATCH_SIZE):
        self.serializer = serializer
        self.batchSize = batchSize

    def _batched(self, iterator):
        if self.batchSize == self.UNLIMITED_BATCH_SIZE:
            yield list(iterator)
        elif hasattr(iterator, "__len__") and hasattr(iterator, "__getslice__"):
            n = len(iterator)
            for i in xrange(0, n, self.batchSize):
                yield iterator[i: i + self.batchSize]
        else:
            items = []
            count = 0
            for item in iterator:
                items.append(item)
                count += 1
                if count == self.batchSize:
                    yield items
                    items = []
                    count = 0
            if items:
                yield items

    def dump_stream(self, iterator, stream):
        self.serializer.dump_stream(self._batched(iterator), stream)

    def load_stream(self, stream):
        return chain.from_iterable(self._load_stream_without_unbatching(stream))

    def _load_stream_without_unbatching(self, stream):
        return self.serializer.load_stream(stream)

    def __repr__(self):
        return "BatchedSerializer(%s, %d)" % (str(self.serializer), self.batchSize)


class FlattenedValuesSerializer(BatchedSerializer):

    """
    Serializes a stream of (key, values) pairs, splitting lists of values
    that contain more than a certain number of objects so that the resulting
    batches have similar sizes.
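
    For example (an illustrative sketch of the batching step):

    >>> ser = FlattenedValuesSerializer(PickleSerializer(), 2)
    >>> list(ser._batched(iter([("k", [1, 2, 3, 4, 5])])))
    [('k', [1, 2]), ('k', [3, 4]), ('k', [5])]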
    """
    def __init__(self, serializer, batchSize=10):
        BatchedSerializer.__init__(self, serializer, batchSize)

    def _batched(self, iterator):
        n = self.batchSize
        for key, values in iterator:
            for i in range(0, len(values), n):
                yield key, values[i:i + n]

    def load_stream(self, stream):
        return self.serializer.load_stream(stream)

    def __repr__(self):
        return "FlattenedValuesSerializer(%s, %d)" % (self.serializer, self.batchSize)


class AutoBatchedSerializer(BatchedSerializer):
    """
    Choose the batch size automatically based on the size of the objects.
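
    The batch size starts at 1 and is doubled while frames stay under
    C{bestSize}, and halved when a frame exceeds roughly ten times that size.
    A round-trip sketch (illustrative):

    >>> import io
    >>> ser = AutoBatchedSerializer(PickleSerializer())
    >>> buf = io.BytesIO()
    >>> ser.dump_stream(iter(range(10)), buf)
    >>> _ = buf.seek(0)
    >>> list(ser.load_stream(buf))
    [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]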
    """

    def __init__(self, serializer, bestSize=1 << 16):
        BatchedSerializer.__init__(self, serializer, self.UNKNOWN_BATCH_SIZE)
        self.bestSize = bestSize

    def dump_stream(self, iterator, stream):
        batch, best = 1, self.bestSize
        iterator = iter(iterator)
        while True:
            vs = list(itertools.islice(iterator, batch))
            if not vs:
                break

            bytes = self.serializer.dumps(vs)
            write_int(len(bytes), stream)
            stream.write(bytes)

            size = len(bytes)
            if size < best:
                batch *= 2
            elif size > best * 10 and batch > 1:
                batch //= 2

    def __repr__(self):
        return "AutoBatchedSerializer(%s)" % self.serializer


class CartesianDeserializer(Serializer):

    """
    Deserializes the JavaRDD cartesian() of two PythonRDDs.
    Due to pyspark batching we cannot simply use the result of the Java RDD cartesian;
    we additionally need to do the cartesian within each pair of batches.
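
    An illustrative sketch, simulating the two interleaved batch streams with
    a L{BatchedSerializer}:

    >>> import io
    >>> ser = BatchedSerializer(PickleSerializer(), 2)
    >>> buf = io.BytesIO()
    >>> ser.dump_stream(iter([1, 2]), buf)      # one batch of keys
    >>> ser.dump_stream(iter(["a", "b"]), buf)  # one batch of values
    >>> _ = buf.seek(0)
    >>> list(CartesianDeserializer(ser, ser).load_stream(buf))
    [(1, 'a'), (1, 'b'), (2, 'a'), (2, 'b')]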
    """

    def __init__(self, key_ser, val_ser):
        self.key_ser = key_ser
        self.val_ser = val_ser

    def _load_stream_without_unbatching(self, stream):
        key_batch_stream = self.key_ser._load_stream_without_unbatching(stream)
        val_batch_stream = self.val_ser._load_stream_without_unbatching(stream)
        for (key_batch, val_batch) in zip(key_batch_stream, val_batch_stream):
            # for correctness with repeated cartesian/zip this must be returned as one batch
            yield product(key_batch, val_batch)

    def load_stream(self, stream):
        return chain.from_iterable(self._load_stream_without_unbatching(stream))

    def __repr__(self):
        return "CartesianDeserializer(%s, %s)" % \
               (str(self.key_ser), str(self.val_ser))


class PairDeserializer(Serializer):

    """
    Deserializes the JavaRDD zip() of two PythonRDDs.
    Due to pyspark batching we cannot simply use the result of the Java RDD zip;
    we additionally need to do the zip within each pair of batches.
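
    An illustrative sketch, again simulating the batch streams with a
    L{BatchedSerializer}:

    >>> import io
    >>> ser = BatchedSerializer(PickleSerializer(), 2)
    >>> buf = io.BytesIO()
    >>> ser.dump_stream(iter([1, 2]), buf)      # one batch of keys
    >>> ser.dump_stream(iter(["a", "b"]), buf)  # one batch of values
    >>> _ = buf.seek(0)
    >>> list(PairDeserializer(ser, ser).load_stream(buf))
    [(1, 'a'), (2, 'b')]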
    """

    def __init__(self, key_ser, val_ser):
        self.key_ser = key_ser
        self.val_ser = val_ser

    def _load_stream_without_unbatching(self, stream):
        key_batch_stream = self.key_ser._load_stream_without_unbatching(stream)
        val_batch_stream = self.val_ser._load_stream_without_unbatching(stream)
        for (key_batch, val_batch) in zip(key_batch_stream, val_batch_stream):
            # for double-zipped RDDs the batches may be iterators (e.g. the zip/product
            # yielded by another Pair/CartesianDeserializer), so materialize them before
            # checking their lengths
            key_batch = key_batch if hasattr(key_batch, '__len__') else list(key_batch)
            val_batch = val_batch if hasattr(val_batch, '__len__') else list(val_batch)
            if len(key_batch) != len(val_batch):
                raise ValueError("Can not deserialize PairRDD with different number of items"
                                 " in batches: (%d, %d)" % (len(key_batch), len(val_batch)))
            # for correctness with repeated cartesian/zip this must be returned as one batch
            yield zip(key_batch, val_batch)

    def load_stream(self, stream):
        return chain.from_iterable(self._load_stream_without_unbatching(stream))

    def __repr__(self):
        return "PairDeserializer(%s, %s)" % (str(self.key_ser), str(self.val_ser))


class NoOpSerializer(FramedSerializer):

    def loads(self, obj):
        return obj

    def dumps(self, obj):
        return obj


# Hook namedtuple, make it picklable

__cls = {}


def _restore(name, fields, value):
    """ Restore an object of namedtuple"""
    k = (name, fields)
    cls = __cls.get(k)
    if cls is None:
        cls = collections.namedtuple(name, fields)
        __cls[k] = cls
    return cls(*value)


def _hack_namedtuple(cls):
    """ Make a class generated by namedtuple picklable """
    name = cls.__name__
    fields = cls._fields

    def __reduce__(self):
        return (_restore, (name, fields, tuple(self)))
    cls.__reduce__ = __reduce__
    cls._is_namedtuple_ = True
    return cls


def _hijack_namedtuple():
    """ Hack namedtuple() to make it picklable """
    # hijack only one time
    if hasattr(collections.namedtuple, "__hijack"):
        return

    global _old_namedtuple  # or it will be put into the closure
    global _old_namedtuple_kwdefaults  # or it will be put into the closure too

    def _copy_func(f):
        return types.FunctionType(f.__code__, f.__globals__, f.__name__,
                                  f.__defaults__, f.__closure__)

    def _kwdefaults(f):
        # __kwdefaults__ contains the default values of keyword-only arguments, which were
        # introduced in Python 3. The possible cases for __kwdefaults__ in namedtuple
        # are as below:
        #
        # - Does not exist in Python 2.
        # - Returns None in <= Python 3.5.x.
        # - Returns a dictionary mapping the keyword-only arguments to their default values
        #   from Python 3.6.x (see https://bugs.python.org/issue25628).
        kargs = getattr(f, "__kwdefaults__", None)
        if kargs is None:
            return {}
        else:
            return kargs

    _old_namedtuple = _copy_func(collections.namedtuple)
    _old_namedtuple_kwdefaults = _kwdefaults(collections.namedtuple)

    def namedtuple(*args, **kwargs):
        for k, v in _old_namedtuple_kwdefaults.items():
            kwargs[k] = kwargs.get(k, v)
        cls = _old_namedtuple(*args, **kwargs)
        return _hack_namedtuple(cls)

    # replace namedtuple with new one
    collections.namedtuple.__globals__["_old_namedtuple_kwdefaults"] = _old_namedtuple_kwdefaults
    collections.namedtuple.__globals__["_old_namedtuple"] = _old_namedtuple
    collections.namedtuple.__globals__["_hack_namedtuple"] = _hack_namedtuple
    collections.namedtuple.__code__ = namedtuple.__code__
    collections.namedtuple.__hijack = 1

    # hack the classes already generated by namedtuple;
    # those created in other modules can be pickled as normal,
    # so only hack those in the __main__ module
    for n, o in sys.modules["__main__"].__dict__.items():
        if (type(o) is type and o.__base__ is tuple
                and hasattr(o, "_fields")
                and "__reduce__" not in o.__dict__):
            _hack_namedtuple(o)  # hack in place


_hijack_namedtuple()
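
# Illustrative effect of the hijack above: a namedtuple class created after this
# module is imported, even one defined interactively or in __main__, can round-trip
# through pickle because its __reduce__ routes through _restore(), e.g.
#
#     Point = collections.namedtuple("Point", "x y")
#     pickle.loads(pickle.dumps(Point(1, 2))) == Point(1, 2)   # True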


class PickleSerializer(FramedSerializer):

    """
    Serializes objects using Python's pickle serializer:

        http://docs.python.org/2/library/pickle.html

    This serializer supports nearly any Python object, but may
    not be as fast as more specialized serializers.
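
    A quick round trip (illustrative):

    >>> ser = PickleSerializer()
    >>> ser.loads(ser.dumps({"a": 1}))
    {'a': 1}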
    """

    def dumps(self, obj):
        return pickle.dumps(obj, protocol)

    if sys.version >= '3':
        def loads(self, obj, encoding="bytes"):
            return pickle.loads(obj, encoding=encoding)
    else:
        def loads(self, obj, encoding=None):
            return pickle.loads(obj)


class CloudPickleSerializer(PickleSerializer):

    def dumps(self, obj):
        return cloudpickle.dumps(obj, 2)


class MarshalSerializer(FramedSerializer):

    """
    Serializes objects using Python's Marshal serializer:

        http://docs.python.org/2/library/marshal.html

    This serializer is faster than PickleSerializer but supports fewer datatypes.
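
    For example (an illustrative round trip):

    >>> ser = MarshalSerializer()
    >>> ser.loads(ser.dumps([1, 2.0, "three"]))
    [1, 2.0, 'three']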
    """

    def dumps(self, obj):
        return marshal.dumps(obj)

    def loads(self, obj):
        return marshal.loads(obj)


class AutoSerializer(FramedSerializer):

    """
    Choose marshal or pickle as the serialization protocol automatically.
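
    An illustrative sketch of the type tag written by C{dumps}:

    >>> ser = AutoSerializer()
    >>> ser.dumps([1, 2])[:1] == b'M'
    True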
    """

    def __init__(self):
        FramedSerializer.__init__(self)
        self._type = None

    def dumps(self, obj):
        if self._type is not None:
            return b'P' + pickle.dumps(obj, -1)
        try:
            return b'M' + marshal.dumps(obj)
        except Exception:
            self._type = b'P'
            return b'P' + pickle.dumps(obj, -1)

    def loads(self, obj):
        # slice (rather than index) so the tag is a bytes object on both Python 2 and 3
        _type = obj[0:1]
        if _type == b'M':
            return marshal.loads(obj[1:])
        elif _type == b'P':
            return pickle.loads(obj[1:])
        else:
            raise ValueError("invalid serialization type: %s" % _type)


class CompressedSerializer(FramedSerializer):
    """
    Compress the serialized data
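
    A round-trip sketch (illustrative):

    >>> ser = CompressedSerializer(PickleSerializer())
    >>> ser.loads(ser.dumps(list(range(10)))) == list(range(10))
    True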
    """
    def __init__(self, serializer):
        FramedSerializer.__init__(self)
        assert isinstance(serializer, FramedSerializer), "serializer must be a FramedSerializer"
        self.serializer = serializer

    def dumps(self, obj):
        return zlib.compress(self.serializer.dumps(obj), 1)

    def loads(self, obj):
        return self.serializer.loads(zlib.decompress(obj))

    def __repr__(self):
        return "CompressedSerializer(%s)" % self.serializer


class UTF8Deserializer(Serializer):

    """
    Deserializes streams written by String.getBytes.
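
    An illustrative sketch using the module-level C{write_with_length} helper:

    >>> import io
    >>> buf = io.BytesIO()
    >>> write_with_length("spark".encode("utf-8"), buf)
    >>> _ = buf.seek(0)
    >>> UTF8Deserializer().loads(buf) == u"spark"
    True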
    """

    def __init__(self, use_unicode=True):
        self.use_unicode = use_unicode

    def loads(self, stream):
        length = read_int(stream)
        if length == SpecialLengths.END_OF_DATA_SECTION:
            raise EOFError
        elif length == SpecialLengths.NULL:
            return None
        s = stream.read(length)
        return s.decode("utf-8") if self.use_unicode else s

    def load_stream(self, stream):
        try:
            while True:
                yield self.loads(stream)
        except struct.error:
            return
        except EOFError:
            return

    def __repr__(self):
        return "UTF8Deserializer(%s)" % self.use_unicode


def read_long(stream):
    length = stream.read(8)
    if not length:
        raise EOFError
    return struct.unpack("!q", length)[0]


def write_long(value, stream):
    stream.write(struct.pack("!q", value))


def pack_long(value):
    return struct.pack("!q", value)


def read_int(stream):
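    """
    Read a 32-bit big-endian signed integer (the framing format used throughout
    this module). An illustrative example:

    >>> import io
    >>> read_int(io.BytesIO(struct.pack("!i", 7)))
    7
    """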
    length = stream.read(4)
    if not length:
        raise EOFError
    return struct.unpack("!i", length)[0]


def write_int(value, stream):
    stream.write(struct.pack("!i", value))


def write_with_length(obj, stream):
    write_int(len(obj), stream)
    stream.write(obj)


if __name__ == '__main__':
    import doctest
    (failure_count, test_count) = doctest.testmod()
    if failure_count:
        exit(-1)