1#
2# Licensed to the Apache Software Foundation (ASF) under one or more
3# contributor license agreements.  See the NOTICE file distributed with
4# this work for additional information regarding copyright ownership.
5# The ASF licenses this file to You under the Apache License, Version 2.0
6# (the "License"); you may not use this file except in compliance with
7# the License.  You may obtain a copy of the License at
8#
9#    http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17
18"""
19Unit tests for PySpark; additional tests are implemented as doctests in
20individual modules.
21"""
22
23from array import array
24from glob import glob
25import os
26import re
27import shutil
28import subprocess
29import sys
30import tempfile
31import time
32import zipfile
33import random
34import threading
35import hashlib
36
37from py4j.protocol import Py4JJavaError
38try:
39    import xmlrunner
40except ImportError:
41    xmlrunner = None
42
43if sys.version_info[:2] <= (2, 6):
44    try:
45        import unittest2 as unittest
46    except ImportError:
47        sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier')
48        sys.exit(1)
49else:
50    import unittest
51    if sys.version_info[0] >= 3:
52        xrange = range
53        basestring = str
54
55if sys.version >= "3":
56    from io import StringIO
57else:
58    from StringIO import StringIO
59
60
61from pyspark import keyword_only
62from pyspark.conf import SparkConf
63from pyspark.context import SparkContext
64from pyspark.rdd import RDD
65from pyspark.files import SparkFiles
66from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer, \
67    CloudPickleSerializer, CompressedSerializer, UTF8Deserializer, NoOpSerializer, \
68    PairDeserializer, CartesianDeserializer, AutoBatchedSerializer, AutoSerializer, \
69    FlattenedValuesSerializer
70from pyspark.shuffle import Aggregator, ExternalMerger, ExternalSorter
71from pyspark import shuffle
72from pyspark.profiler import BasicProfiler
73
74_have_scipy = False
75_have_numpy = False
76try:
77    import scipy.sparse
78    _have_scipy = True
79except:
80    # No SciPy, but that's okay, we'll skip those tests
81    pass
82try:
83    import numpy as np
84    _have_numpy = True
85except:
86    # No NumPy, but that's okay, we'll skip those tests
87    pass
88
89
90SPARK_HOME = os.environ["SPARK_HOME"]
91
92
93class MergerTests(unittest.TestCase):
94
95    def setUp(self):
96        self.N = 1 << 12
97        self.l = [i for i in xrange(self.N)]
98        self.data = list(zip(self.l, self.l))
99        self.agg = Aggregator(lambda x: [x],
100                              lambda x, y: x.append(y) or x,
101                              lambda x, y: x.extend(y) or x)
102
103    def test_small_dataset(self):
104        m = ExternalMerger(self.agg, 1000)
105        m.mergeValues(self.data)
106        self.assertEqual(m.spills, 0)
107        self.assertEqual(sum(sum(v) for k, v in m.items()),
108                         sum(xrange(self.N)))
109
110        m = ExternalMerger(self.agg, 1000)
111        m.mergeCombiners(map(lambda x_y1: (x_y1[0], [x_y1[1]]), self.data))
112        self.assertEqual(m.spills, 0)
113        self.assertEqual(sum(sum(v) for k, v in m.items()),
114                         sum(xrange(self.N)))
115
116    def test_medium_dataset(self):
117        m = ExternalMerger(self.agg, 20)
118        m.mergeValues(self.data)
119        self.assertTrue(m.spills >= 1)
120        self.assertEqual(sum(sum(v) for k, v in m.items()),
121                         sum(xrange(self.N)))
122
123        m = ExternalMerger(self.agg, 10)
124        m.mergeCombiners(map(lambda x_y2: (x_y2[0], [x_y2[1]]), self.data * 3))
125        self.assertTrue(m.spills >= 1)
126        self.assertEqual(sum(sum(v) for k, v in m.items()),
127                         sum(xrange(self.N)) * 3)
128
129    def test_huge_dataset(self):
130        m = ExternalMerger(self.agg, 5, partitions=3)
131        m.mergeCombiners(map(lambda k_v: (k_v[0], [str(k_v[1])]), self.data * 10))
132        self.assertTrue(m.spills >= 1)
133        self.assertEqual(sum(len(v) for k, v in m.items()),
134                         self.N * 10)
135        m._cleanup()
136
137    def test_group_by_key(self):
138
139        def gen_data(N, step):
140            for i in range(1, N + 1, step):
141                for j in range(i):
142                    yield (i, [j])
143
144        def gen_gs(N, step=1):
145            return shuffle.GroupByKey(gen_data(N, step))
146
147        self.assertEqual(1, len(list(gen_gs(1))))
148        self.assertEqual(2, len(list(gen_gs(2))))
149        self.assertEqual(100, len(list(gen_gs(100))))
150        self.assertEqual(list(range(1, 101)), [k for k, _ in gen_gs(100)])
151        self.assertTrue(all(list(range(k)) == list(vs) for k, vs in gen_gs(100)))
152
153        for k, vs in gen_gs(50002, 10000):
154            self.assertEqual(k, len(vs))
155            self.assertEqual(list(range(k)), list(vs))
156
157        ser = PickleSerializer()
158        l = ser.loads(ser.dumps(list(gen_gs(50002, 30000))))
159        for k, vs in l:
160            self.assertEqual(k, len(vs))
161            self.assertEqual(list(range(k)), list(vs))
162
163
164class SorterTests(unittest.TestCase):
165    def test_in_memory_sort(self):
166        l = list(range(1024))
167        random.shuffle(l)
168        sorter = ExternalSorter(1024)
169        self.assertEqual(sorted(l), list(sorter.sorted(l)))
170        self.assertEqual(sorted(l, reverse=True), list(sorter.sorted(l, reverse=True)))
171        self.assertEqual(sorted(l, key=lambda x: -x), list(sorter.sorted(l, key=lambda x: -x)))
172        self.assertEqual(sorted(l, key=lambda x: -x, reverse=True),
173                         list(sorter.sorted(l, key=lambda x: -x, reverse=True)))
174
175    def test_external_sort(self):
176        class CustomizedSorter(ExternalSorter):
177            def _next_limit(self):
178                return self.memory_limit
179        l = list(range(1024))
180        random.shuffle(l)
181        sorter = CustomizedSorter(1)
182        self.assertEqual(sorted(l), list(sorter.sorted(l)))
183        self.assertGreater(shuffle.DiskBytesSpilled, 0)
184        last = shuffle.DiskBytesSpilled
185        self.assertEqual(sorted(l, reverse=True), list(sorter.sorted(l, reverse=True)))
186        self.assertGreater(shuffle.DiskBytesSpilled, last)
187        last = shuffle.DiskBytesSpilled
188        self.assertEqual(sorted(l, key=lambda x: -x), list(sorter.sorted(l, key=lambda x: -x)))
189        self.assertGreater(shuffle.DiskBytesSpilled, last)
190        last = shuffle.DiskBytesSpilled
191        self.assertEqual(sorted(l, key=lambda x: -x, reverse=True),
192                         list(sorter.sorted(l, key=lambda x: -x, reverse=True)))
193        self.assertGreater(shuffle.DiskBytesSpilled, last)
194
195    def test_external_sort_in_rdd(self):
196        conf = SparkConf().set("spark.python.worker.memory", "1m")
197        sc = SparkContext(conf=conf)
198        l = list(range(10240))
199        random.shuffle(l)
200        rdd = sc.parallelize(l, 4)
201        self.assertEqual(sorted(l), rdd.sortBy(lambda x: x).collect())
202        sc.stop()
203
204
205class SerializationTestCase(unittest.TestCase):
206
207    def test_namedtuple(self):
208        from collections import namedtuple
209        from pickle import dumps, loads
210        P = namedtuple("P", "x y")
211        p1 = P(1, 3)
212        p2 = loads(dumps(p1, 2))
213        self.assertEqual(p1, p2)
214
215        from pyspark.cloudpickle import dumps
216        P2 = loads(dumps(P))
217        p3 = P2(1, 3)
218        self.assertEqual(p1, p3)
219
220    def test_itemgetter(self):
221        from operator import itemgetter
222        ser = CloudPickleSerializer()
223        d = range(10)
224        getter = itemgetter(1)
225        getter2 = ser.loads(ser.dumps(getter))
226        self.assertEqual(getter(d), getter2(d))
227
228        getter = itemgetter(0, 3)
229        getter2 = ser.loads(ser.dumps(getter))
230        self.assertEqual(getter(d), getter2(d))
231
232    def test_function_module_name(self):
233        ser = CloudPickleSerializer()
234        func = lambda x: x
235        func2 = ser.loads(ser.dumps(func))
236        self.assertEqual(func.__module__, func2.__module__)
237
238    def test_attrgetter(self):
239        from operator import attrgetter
240        ser = CloudPickleSerializer()
241
242        class C(object):
243            def __getattr__(self, item):
244                return item
245        d = C()
246        getter = attrgetter("a")
247        getter2 = ser.loads(ser.dumps(getter))
248        self.assertEqual(getter(d), getter2(d))
249        getter = attrgetter("a", "b")
250        getter2 = ser.loads(ser.dumps(getter))
251        self.assertEqual(getter(d), getter2(d))
252
253        d.e = C()
254        getter = attrgetter("e.a")
255        getter2 = ser.loads(ser.dumps(getter))
256        self.assertEqual(getter(d), getter2(d))
257        getter = attrgetter("e.a", "e.b")
258        getter2 = ser.loads(ser.dumps(getter))
259        self.assertEqual(getter(d), getter2(d))
260
261    # Regression test for SPARK-3415
262    def test_pickling_file_handles(self):
263        # to be corrected with SPARK-11160
264        if not xmlrunner:
265            ser = CloudPickleSerializer()
266            out1 = sys.stderr
267            out2 = ser.loads(ser.dumps(out1))
268            self.assertEqual(out1, out2)
269
270    def test_func_globals(self):
271
272        class Unpicklable(object):
273            def __reduce__(self):
274                raise Exception("not picklable")
275
276        global exit
277        exit = Unpicklable()
278
279        ser = CloudPickleSerializer()
280        self.assertRaises(Exception, lambda: ser.dumps(exit))
281
282        def foo():
283            sys.exit(0)
284
285        self.assertTrue("exit" in foo.__code__.co_names)
286        ser.dumps(foo)
287
288    def test_compressed_serializer(self):
289        ser = CompressedSerializer(PickleSerializer())
290        try:
291            from StringIO import StringIO
292        except ImportError:
293            from io import BytesIO as StringIO
294        io = StringIO()
295        ser.dump_stream(["abc", u"123", range(5)], io)
296        io.seek(0)
297        self.assertEqual(["abc", u"123", range(5)], list(ser.load_stream(io)))
298        ser.dump_stream(range(1000), io)
299        io.seek(0)
300        self.assertEqual(["abc", u"123", range(5)] + list(range(1000)), list(ser.load_stream(io)))
301        io.close()
302
303    def test_hash_serializer(self):
304        hash(NoOpSerializer())
305        hash(UTF8Deserializer())
306        hash(PickleSerializer())
307        hash(MarshalSerializer())
308        hash(AutoSerializer())
309        hash(BatchedSerializer(PickleSerializer()))
310        hash(AutoBatchedSerializer(MarshalSerializer()))
311        hash(PairDeserializer(NoOpSerializer(), UTF8Deserializer()))
312        hash(CartesianDeserializer(NoOpSerializer(), UTF8Deserializer()))
313        hash(CompressedSerializer(PickleSerializer()))
314        hash(FlattenedValuesSerializer(PickleSerializer()))
315
316
317class QuietTest(object):
318    def __init__(self, sc):
319        self.log4j = sc._jvm.org.apache.log4j
320
321    def __enter__(self):
322        self.old_level = self.log4j.LogManager.getRootLogger().getLevel()
323        self.log4j.LogManager.getRootLogger().setLevel(self.log4j.Level.FATAL)
324
325    def __exit__(self, exc_type, exc_val, exc_tb):
326        self.log4j.LogManager.getRootLogger().setLevel(self.old_level)
327
328
329class PySparkTestCase(unittest.TestCase):
330
331    def setUp(self):
332        self._old_sys_path = list(sys.path)
333        class_name = self.__class__.__name__
334        self.sc = SparkContext('local[4]', class_name)
335
336    def tearDown(self):
337        self.sc.stop()
338        sys.path = self._old_sys_path
339
340
341class ReusedPySparkTestCase(unittest.TestCase):
342
343    @classmethod
344    def setUpClass(cls):
345        cls.sc = SparkContext('local[4]', cls.__name__)
346
347    @classmethod
348    def tearDownClass(cls):
349        cls.sc.stop()
350
351
352class CheckpointTests(ReusedPySparkTestCase):
353
354    def setUp(self):
355        self.checkpointDir = tempfile.NamedTemporaryFile(delete=False)
356        os.unlink(self.checkpointDir.name)
357        self.sc.setCheckpointDir(self.checkpointDir.name)
358
359    def tearDown(self):
360        shutil.rmtree(self.checkpointDir.name)
361
362    def test_basic_checkpointing(self):
363        parCollection = self.sc.parallelize([1, 2, 3, 4])
364        flatMappedRDD = parCollection.flatMap(lambda x: range(1, x + 1))
365
366        self.assertFalse(flatMappedRDD.isCheckpointed())
367        self.assertTrue(flatMappedRDD.getCheckpointFile() is None)
368
369        flatMappedRDD.checkpoint()
370        result = flatMappedRDD.collect()
371        time.sleep(1)  # 1 second
372        self.assertTrue(flatMappedRDD.isCheckpointed())
373        self.assertEqual(flatMappedRDD.collect(), result)
374        self.assertEqual("file:" + self.checkpointDir.name,
375                         os.path.dirname(os.path.dirname(flatMappedRDD.getCheckpointFile())))
376
377    def test_checkpoint_and_restore(self):
378        parCollection = self.sc.parallelize([1, 2, 3, 4])
379        flatMappedRDD = parCollection.flatMap(lambda x: [x])
380
381        self.assertFalse(flatMappedRDD.isCheckpointed())
382        self.assertTrue(flatMappedRDD.getCheckpointFile() is None)
383
384        flatMappedRDD.checkpoint()
385        flatMappedRDD.count()  # forces a checkpoint to be computed
386        time.sleep(1)  # 1 second
387
388        self.assertTrue(flatMappedRDD.getCheckpointFile() is not None)
389        recovered = self.sc._checkpointFile(flatMappedRDD.getCheckpointFile(),
390                                            flatMappedRDD._jrdd_deserializer)
391        self.assertEqual([1, 2, 3, 4], recovered.collect())
392
393
394class LocalCheckpointTests(ReusedPySparkTestCase):
395
396    def test_basic_localcheckpointing(self):
397        parCollection = self.sc.parallelize([1, 2, 3, 4])
398        flatMappedRDD = parCollection.flatMap(lambda x: range(1, x + 1))
399
400        self.assertFalse(flatMappedRDD.isCheckpointed())
401        self.assertFalse(flatMappedRDD.isLocallyCheckpointed())
402
403        flatMappedRDD.localCheckpoint()
404        result = flatMappedRDD.collect()
405        time.sleep(1)  # 1 second
406        self.assertTrue(flatMappedRDD.isCheckpointed())
407        self.assertTrue(flatMappedRDD.isLocallyCheckpointed())
408        self.assertEqual(flatMappedRDD.collect(), result)
409
410
411class AddFileTests(PySparkTestCase):
412
413    def test_add_py_file(self):
414        # To ensure that we're actually testing addPyFile's effects, check that
415        # this job fails due to `userlibrary` not being on the Python path:
416        # disable logging in log4j temporarily
417        def func(x):
418            from userlibrary import UserClass
419            return UserClass().hello()
420        with QuietTest(self.sc):
421            self.assertRaises(Exception, self.sc.parallelize(range(2)).map(func).first)
422
423        # Add the file, so the job should now succeed:
424        path = os.path.join(SPARK_HOME, "python/test_support/userlibrary.py")
425        self.sc.addPyFile(path)
426        res = self.sc.parallelize(range(2)).map(func).first()
427        self.assertEqual("Hello World!", res)
428
429    def test_add_file_locally(self):
430        path = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt")
431        self.sc.addFile(path)
432        download_path = SparkFiles.get("hello.txt")
433        self.assertNotEqual(path, download_path)
434        with open(download_path) as test_file:
435            self.assertEqual("Hello World!\n", test_file.readline())
436
437    def test_add_file_recursively_locally(self):
438        path = os.path.join(SPARK_HOME, "python/test_support/hello")
439        self.sc.addFile(path, True)
440        download_path = SparkFiles.get("hello")
441        self.assertNotEqual(path, download_path)
442        with open(download_path + "/hello.txt") as test_file:
443            self.assertEqual("Hello World!\n", test_file.readline())
444        with open(download_path + "/sub_hello/sub_hello.txt") as test_file:
445            self.assertEqual("Sub Hello World!\n", test_file.readline())
446
447    def test_add_py_file_locally(self):
448        # To ensure that we're actually testing addPyFile's effects, check that
449        # this fails due to `userlibrary` not being on the Python path:
450        def func():
451            from userlibrary import UserClass
452        self.assertRaises(ImportError, func)
453        path = os.path.join(SPARK_HOME, "python/test_support/userlibrary.py")
454        self.sc.addPyFile(path)
455        from userlibrary import UserClass
456        self.assertEqual("Hello World!", UserClass().hello())
457
458    def test_add_egg_file_locally(self):
459        # To ensure that we're actually testing addPyFile's effects, check that
460        # this fails due to `userlibrary` not being on the Python path:
461        def func():
462            from userlib import UserClass
463        self.assertRaises(ImportError, func)
464        path = os.path.join(SPARK_HOME, "python/test_support/userlib-0.1.zip")
465        self.sc.addPyFile(path)
466        from userlib import UserClass
467        self.assertEqual("Hello World from inside a package!", UserClass().hello())
468
469    def test_overwrite_system_module(self):
470        self.sc.addPyFile(os.path.join(SPARK_HOME, "python/test_support/SimpleHTTPServer.py"))
471
472        import SimpleHTTPServer
473        self.assertEqual("My Server", SimpleHTTPServer.__name__)
474
475        def func(x):
476            import SimpleHTTPServer
477            return SimpleHTTPServer.__name__
478
479        self.assertEqual(["My Server"], self.sc.parallelize(range(1)).map(func).collect())
480
481
482class RDDTests(ReusedPySparkTestCase):
483
484    def test_range(self):
485        self.assertEqual(self.sc.range(1, 1).count(), 0)
486        self.assertEqual(self.sc.range(1, 0, -1).count(), 1)
487        self.assertEqual(self.sc.range(0, 1 << 40, 1 << 39).count(), 2)
488
489    def test_id(self):
490        rdd = self.sc.parallelize(range(10))
491        id = rdd.id()
492        self.assertEqual(id, rdd.id())
493        rdd2 = rdd.map(str).filter(bool)
494        id2 = rdd2.id()
495        self.assertEqual(id + 1, id2)
496        self.assertEqual(id2, rdd2.id())
497
498    def test_empty_rdd(self):
499        rdd = self.sc.emptyRDD()
500        self.assertTrue(rdd.isEmpty())
501
502    def test_sum(self):
503        self.assertEqual(0, self.sc.emptyRDD().sum())
504        self.assertEqual(6, self.sc.parallelize([1, 2, 3]).sum())
505
506    def test_to_localiterator(self):
507        from time import sleep
508        rdd = self.sc.parallelize([1, 2, 3])
509        it = rdd.toLocalIterator()
510        sleep(5)
511        self.assertEqual([1, 2, 3], sorted(it))
512
513        rdd2 = rdd.repartition(1000)
514        it2 = rdd2.toLocalIterator()
515        sleep(5)
516        self.assertEqual([1, 2, 3], sorted(it2))
517
518    def test_save_as_textfile_with_unicode(self):
519        # Regression test for SPARK-970
520        x = u"\u00A1Hola, mundo!"
521        data = self.sc.parallelize([x])
522        tempFile = tempfile.NamedTemporaryFile(delete=True)
523        tempFile.close()
524        data.saveAsTextFile(tempFile.name)
525        raw_contents = b''.join(open(p, 'rb').read()
526                                for p in glob(tempFile.name + "/part-0000*"))
527        self.assertEqual(x, raw_contents.strip().decode("utf-8"))
528
529    def test_save_as_textfile_with_utf8(self):
530        x = u"\u00A1Hola, mundo!"
531        data = self.sc.parallelize([x.encode("utf-8")])
532        tempFile = tempfile.NamedTemporaryFile(delete=True)
533        tempFile.close()
534        data.saveAsTextFile(tempFile.name)
535        raw_contents = b''.join(open(p, 'rb').read()
536                                for p in glob(tempFile.name + "/part-0000*"))
537        self.assertEqual(x, raw_contents.strip().decode('utf8'))
538
539    def test_transforming_cartesian_result(self):
540        # Regression test for SPARK-1034
541        rdd1 = self.sc.parallelize([1, 2])
542        rdd2 = self.sc.parallelize([3, 4])
543        cart = rdd1.cartesian(rdd2)
544        result = cart.map(lambda x_y3: x_y3[0] + x_y3[1]).collect()
545
546    def test_transforming_pickle_file(self):
547        # Regression test for SPARK-2601
548        data = self.sc.parallelize([u"Hello", u"World!"])
549        tempFile = tempfile.NamedTemporaryFile(delete=True)
550        tempFile.close()
551        data.saveAsPickleFile(tempFile.name)
552        pickled_file = self.sc.pickleFile(tempFile.name)
553        pickled_file.map(lambda x: x).collect()
554
555    def test_cartesian_on_textfile(self):
556        # Regression test for
557        path = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt")
558        a = self.sc.textFile(path)
559        result = a.cartesian(a).collect()
560        (x, y) = result[0]
561        self.assertEqual(u"Hello World!", x.strip())
562        self.assertEqual(u"Hello World!", y.strip())
563
564    def test_cartesian_chaining(self):
565        # Tests for SPARK-16589
566        rdd = self.sc.parallelize(range(10), 2)
567        self.assertSetEqual(
568            set(rdd.cartesian(rdd).cartesian(rdd).collect()),
569            set([((x, y), z) for x in range(10) for y in range(10) for z in range(10)])
570        )
571
572        self.assertSetEqual(
573            set(rdd.cartesian(rdd.cartesian(rdd)).collect()),
574            set([(x, (y, z)) for x in range(10) for y in range(10) for z in range(10)])
575        )
576
577        self.assertSetEqual(
578            set(rdd.cartesian(rdd.zip(rdd)).collect()),
579            set([(x, (y, y)) for x in range(10) for y in range(10)])
580        )
581
582    def test_deleting_input_files(self):
583        # Regression test for SPARK-1025
584        tempFile = tempfile.NamedTemporaryFile(delete=False)
585        tempFile.write(b"Hello World!")
586        tempFile.close()
587        data = self.sc.textFile(tempFile.name)
588        filtered_data = data.filter(lambda x: True)
589        self.assertEqual(1, filtered_data.count())
590        os.unlink(tempFile.name)
591        with QuietTest(self.sc):
592            self.assertRaises(Exception, lambda: filtered_data.count())
593
594    def test_sampling_default_seed(self):
595        # Test for SPARK-3995 (default seed setting)
596        data = self.sc.parallelize(xrange(1000), 1)
597        subset = data.takeSample(False, 10)
598        self.assertEqual(len(subset), 10)
599
600    def test_aggregate_mutable_zero_value(self):
601        # Test for SPARK-9021; uses aggregate and treeAggregate to build dict
602        # representing a counter of ints
603        # NOTE: dict is used instead of collections.Counter for Python 2.6
604        # compatibility
605        from collections import defaultdict
606
607        # Show that single or multiple partitions work
608        data1 = self.sc.range(10, numSlices=1)
609        data2 = self.sc.range(10, numSlices=2)
610
611        def seqOp(x, y):
612            x[y] += 1
613            return x
614
615        def comboOp(x, y):
616            for key, val in y.items():
617                x[key] += val
618            return x
619
620        counts1 = data1.aggregate(defaultdict(int), seqOp, comboOp)
621        counts2 = data2.aggregate(defaultdict(int), seqOp, comboOp)
622        counts3 = data1.treeAggregate(defaultdict(int), seqOp, comboOp, 2)
623        counts4 = data2.treeAggregate(defaultdict(int), seqOp, comboOp, 2)
624
625        ground_truth = defaultdict(int, dict((i, 1) for i in range(10)))
626        self.assertEqual(counts1, ground_truth)
627        self.assertEqual(counts2, ground_truth)
628        self.assertEqual(counts3, ground_truth)
629        self.assertEqual(counts4, ground_truth)
630
631    def test_aggregate_by_key_mutable_zero_value(self):
632        # Test for SPARK-9021; uses aggregateByKey to make a pair RDD that
633        # contains lists of all values for each key in the original RDD
634
635        # list(range(...)) for Python 3.x compatibility (can't use * operator
636        # on a range object)
637        # list(zip(...)) for Python 3.x compatibility (want to parallelize a
638        # collection, not a zip object)
639        tuples = list(zip(list(range(10))*2, [1]*20))
640        # Show that single or multiple partitions work
641        data1 = self.sc.parallelize(tuples, 1)
642        data2 = self.sc.parallelize(tuples, 2)
643
644        def seqOp(x, y):
645            x.append(y)
646            return x
647
648        def comboOp(x, y):
649            x.extend(y)
650            return x
651
652        values1 = data1.aggregateByKey([], seqOp, comboOp).collect()
653        values2 = data2.aggregateByKey([], seqOp, comboOp).collect()
654        # Sort lists to ensure clean comparison with ground_truth
655        values1.sort()
656        values2.sort()
657
658        ground_truth = [(i, [1]*2) for i in range(10)]
659        self.assertEqual(values1, ground_truth)
660        self.assertEqual(values2, ground_truth)
661
662    def test_fold_mutable_zero_value(self):
663        # Test for SPARK-9021; uses fold to merge an RDD of dict counters into
664        # a single dict
665        # NOTE: dict is used instead of collections.Counter for Python 2.6
666        # compatibility
667        from collections import defaultdict
668
669        counts1 = defaultdict(int, dict((i, 1) for i in range(10)))
670        counts2 = defaultdict(int, dict((i, 1) for i in range(3, 8)))
671        counts3 = defaultdict(int, dict((i, 1) for i in range(4, 7)))
672        counts4 = defaultdict(int, dict((i, 1) for i in range(5, 6)))
673        all_counts = [counts1, counts2, counts3, counts4]
674        # Show that single or multiple partitions work
675        data1 = self.sc.parallelize(all_counts, 1)
676        data2 = self.sc.parallelize(all_counts, 2)
677
678        def comboOp(x, y):
679            for key, val in y.items():
680                x[key] += val
681            return x
682
683        fold1 = data1.fold(defaultdict(int), comboOp)
684        fold2 = data2.fold(defaultdict(int), comboOp)
685
686        ground_truth = defaultdict(int)
687        for counts in all_counts:
688            for key, val in counts.items():
689                ground_truth[key] += val
690        self.assertEqual(fold1, ground_truth)
691        self.assertEqual(fold2, ground_truth)
692
693    def test_fold_by_key_mutable_zero_value(self):
694        # Test for SPARK-9021; uses foldByKey to make a pair RDD that contains
695        # lists of all values for each key in the original RDD
696
697        tuples = [(i, range(i)) for i in range(10)]*2
698        # Show that single or multiple partitions work
699        data1 = self.sc.parallelize(tuples, 1)
700        data2 = self.sc.parallelize(tuples, 2)
701
702        def comboOp(x, y):
703            x.extend(y)
704            return x
705
706        values1 = data1.foldByKey([], comboOp).collect()
707        values2 = data2.foldByKey([], comboOp).collect()
708        # Sort lists to ensure clean comparison with ground_truth
709        values1.sort()
710        values2.sort()
711
712        # list(range(...)) for Python 3.x compatibility
713        ground_truth = [(i, list(range(i))*2) for i in range(10)]
714        self.assertEqual(values1, ground_truth)
715        self.assertEqual(values2, ground_truth)
716
717    def test_aggregate_by_key(self):
718        data = self.sc.parallelize([(1, 1), (1, 1), (3, 2), (5, 1), (5, 3)], 2)
719
720        def seqOp(x, y):
721            x.add(y)
722            return x
723
724        def combOp(x, y):
725            x |= y
726            return x
727
728        sets = dict(data.aggregateByKey(set(), seqOp, combOp).collect())
729        self.assertEqual(3, len(sets))
730        self.assertEqual(set([1]), sets[1])
731        self.assertEqual(set([2]), sets[3])
732        self.assertEqual(set([1, 3]), sets[5])
733
734    def test_itemgetter(self):
735        rdd = self.sc.parallelize([range(10)])
736        from operator import itemgetter
737        self.assertEqual([1], rdd.map(itemgetter(1)).collect())
738        self.assertEqual([(2, 3)], rdd.map(itemgetter(2, 3)).collect())
739
740    def test_namedtuple_in_rdd(self):
741        from collections import namedtuple
742        Person = namedtuple("Person", "id firstName lastName")
743        jon = Person(1, "Jon", "Doe")
744        jane = Person(2, "Jane", "Doe")
745        theDoes = self.sc.parallelize([jon, jane])
746        self.assertEqual([jon, jane], theDoes.collect())
747
748    def test_large_broadcast(self):
749        N = 10000
750        data = [[float(i) for i in range(300)] for i in range(N)]
751        bdata = self.sc.broadcast(data)  # 27MB
752        m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum()
753        self.assertEqual(N, m)
754
755    def test_unpersist(self):
756        N = 1000
757        data = [[float(i) for i in range(300)] for i in range(N)]
758        bdata = self.sc.broadcast(data)  # 3MB
759        bdata.unpersist()
760        m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum()
761        self.assertEqual(N, m)
762        bdata.destroy()
763        try:
764            self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum()
765        except Exception as e:
766            pass
767        else:
768            raise Exception("job should fail after destroy the broadcast")
769
770    def test_multiple_broadcasts(self):
771        N = 1 << 21
772        b1 = self.sc.broadcast(set(range(N)))  # multiple blocks in JVM
773        r = list(range(1 << 15))
774        random.shuffle(r)
775        s = str(r).encode()
776        checksum = hashlib.md5(s).hexdigest()
777        b2 = self.sc.broadcast(s)
778        r = list(set(self.sc.parallelize(range(10), 10).map(
779            lambda x: (len(b1.value), hashlib.md5(b2.value).hexdigest())).collect()))
780        self.assertEqual(1, len(r))
781        size, csum = r[0]
782        self.assertEqual(N, size)
783        self.assertEqual(checksum, csum)
784
785        random.shuffle(r)
786        s = str(r).encode()
787        checksum = hashlib.md5(s).hexdigest()
788        b2 = self.sc.broadcast(s)
789        r = list(set(self.sc.parallelize(range(10), 10).map(
790            lambda x: (len(b1.value), hashlib.md5(b2.value).hexdigest())).collect()))
791        self.assertEqual(1, len(r))
792        size, csum = r[0]
793        self.assertEqual(N, size)
794        self.assertEqual(checksum, csum)
795
796    def test_large_closure(self):
797        N = 200000
798        data = [float(i) for i in xrange(N)]
799        rdd = self.sc.parallelize(range(1), 1).map(lambda x: len(data))
800        self.assertEqual(N, rdd.first())
801        # regression test for SPARK-6886
802        self.assertEqual(1, rdd.map(lambda x: (x, 1)).groupByKey().count())
803
804    def test_zip_with_different_serializers(self):
805        a = self.sc.parallelize(range(5))
806        b = self.sc.parallelize(range(100, 105))
807        self.assertEqual(a.zip(b).collect(), [(0, 100), (1, 101), (2, 102), (3, 103), (4, 104)])
808        a = a._reserialize(BatchedSerializer(PickleSerializer(), 2))
809        b = b._reserialize(MarshalSerializer())
810        self.assertEqual(a.zip(b).collect(), [(0, 100), (1, 101), (2, 102), (3, 103), (4, 104)])
811        # regression test for SPARK-4841
812        path = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt")
813        t = self.sc.textFile(path)
814        cnt = t.count()
815        self.assertEqual(cnt, t.zip(t).count())
816        rdd = t.map(str)
817        self.assertEqual(cnt, t.zip(rdd).count())
818        # regression test for bug in _reserializer()
819        self.assertEqual(cnt, t.zip(rdd).count())
820
821    def test_zip_with_different_object_sizes(self):
822        # regress test for SPARK-5973
823        a = self.sc.parallelize(xrange(10000)).map(lambda i: '*' * i)
824        b = self.sc.parallelize(xrange(10000, 20000)).map(lambda i: '*' * i)
825        self.assertEqual(10000, a.zip(b).count())
826
827    def test_zip_with_different_number_of_items(self):
828        a = self.sc.parallelize(range(5), 2)
829        # different number of partitions
830        b = self.sc.parallelize(range(100, 106), 3)
831        self.assertRaises(ValueError, lambda: a.zip(b))
832        with QuietTest(self.sc):
833            # different number of batched items in JVM
834            b = self.sc.parallelize(range(100, 104), 2)
835            self.assertRaises(Exception, lambda: a.zip(b).count())
836            # different number of items in one pair
837            b = self.sc.parallelize(range(100, 106), 2)
838            self.assertRaises(Exception, lambda: a.zip(b).count())
839            # same total number of items, but different distributions
840            a = self.sc.parallelize([2, 3], 2).flatMap(range)
841            b = self.sc.parallelize([3, 2], 2).flatMap(range)
842            self.assertEqual(a.count(), b.count())
843            self.assertRaises(Exception, lambda: a.zip(b).count())
844
845    def test_count_approx_distinct(self):
846        rdd = self.sc.parallelize(xrange(1000))
847        self.assertTrue(950 < rdd.countApproxDistinct(0.03) < 1050)
848        self.assertTrue(950 < rdd.map(float).countApproxDistinct(0.03) < 1050)
849        self.assertTrue(950 < rdd.map(str).countApproxDistinct(0.03) < 1050)
850        self.assertTrue(950 < rdd.map(lambda x: (x, -x)).countApproxDistinct(0.03) < 1050)
851
852        rdd = self.sc.parallelize([i % 20 for i in range(1000)], 7)
853        self.assertTrue(18 < rdd.countApproxDistinct() < 22)
854        self.assertTrue(18 < rdd.map(float).countApproxDistinct() < 22)
855        self.assertTrue(18 < rdd.map(str).countApproxDistinct() < 22)
856        self.assertTrue(18 < rdd.map(lambda x: (x, -x)).countApproxDistinct() < 22)
857
858        self.assertRaises(ValueError, lambda: rdd.countApproxDistinct(0.00000001))
859
860    def test_histogram(self):
861        # empty
862        rdd = self.sc.parallelize([])
863        self.assertEqual([0], rdd.histogram([0, 10])[1])
864        self.assertEqual([0, 0], rdd.histogram([0, 4, 10])[1])
865        self.assertRaises(ValueError, lambda: rdd.histogram(1))
866
867        # out of range
868        rdd = self.sc.parallelize([10.01, -0.01])
869        self.assertEqual([0], rdd.histogram([0, 10])[1])
870        self.assertEqual([0, 0], rdd.histogram((0, 4, 10))[1])
871
872        # in range with one bucket
873        rdd = self.sc.parallelize(range(1, 5))
874        self.assertEqual([4], rdd.histogram([0, 10])[1])
875        self.assertEqual([3, 1], rdd.histogram([0, 4, 10])[1])
876
877        # in range with one bucket exact match
878        self.assertEqual([4], rdd.histogram([1, 4])[1])
879
880        # out of range with two buckets
881        rdd = self.sc.parallelize([10.01, -0.01])
882        self.assertEqual([0, 0], rdd.histogram([0, 5, 10])[1])
883
884        # out of range with two uneven buckets
885        rdd = self.sc.parallelize([10.01, -0.01])
886        self.assertEqual([0, 0], rdd.histogram([0, 4, 10])[1])
887
888        # in range with two buckets
889        rdd = self.sc.parallelize([1, 2, 3, 5, 6])
890        self.assertEqual([3, 2], rdd.histogram([0, 5, 10])[1])
891
892        # in range with two bucket and None
893        rdd = self.sc.parallelize([1, 2, 3, 5, 6, None, float('nan')])
894        self.assertEqual([3, 2], rdd.histogram([0, 5, 10])[1])
895
896        # in range with two uneven buckets
897        rdd = self.sc.parallelize([1, 2, 3, 5, 6])
898        self.assertEqual([3, 2], rdd.histogram([0, 5, 11])[1])
899
900        # mixed range with two uneven buckets
901        rdd = self.sc.parallelize([-0.01, 0.0, 1, 2, 3, 5, 6, 11.0, 11.01])
902        self.assertEqual([4, 3], rdd.histogram([0, 5, 11])[1])
903
904        # mixed range with four uneven buckets
905        rdd = self.sc.parallelize([-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0, 200.0, 200.1])
906        self.assertEqual([4, 2, 1, 3], rdd.histogram([0.0, 5.0, 11.0, 12.0, 200.0])[1])
907
908        # mixed range with uneven buckets and NaN
909        rdd = self.sc.parallelize([-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0,
910                                   199.0, 200.0, 200.1, None, float('nan')])
911        self.assertEqual([4, 2, 1, 3], rdd.histogram([0.0, 5.0, 11.0, 12.0, 200.0])[1])
912
913        # out of range with infinite buckets
914        rdd = self.sc.parallelize([10.01, -0.01, float('nan'), float("inf")])
915        self.assertEqual([1, 2], rdd.histogram([float('-inf'), 0, float('inf')])[1])
916
917        # invalid buckets
918        self.assertRaises(ValueError, lambda: rdd.histogram([]))
919        self.assertRaises(ValueError, lambda: rdd.histogram([1]))
920        self.assertRaises(ValueError, lambda: rdd.histogram(0))
921        self.assertRaises(TypeError, lambda: rdd.histogram({}))
922
923        # without buckets
924        rdd = self.sc.parallelize(range(1, 5))
925        self.assertEqual(([1, 4], [4]), rdd.histogram(1))
926
927        # without buckets single element
928        rdd = self.sc.parallelize([1])
929        self.assertEqual(([1, 1], [1]), rdd.histogram(1))
930
931        # without bucket no range
932        rdd = self.sc.parallelize([1] * 4)
933        self.assertEqual(([1, 1], [4]), rdd.histogram(1))
934
935        # without buckets basic two
936        rdd = self.sc.parallelize(range(1, 5))
937        self.assertEqual(([1, 2.5, 4], [2, 2]), rdd.histogram(2))
938
939        # without buckets with more requested than elements
940        rdd = self.sc.parallelize([1, 2])
941        buckets = [1 + 0.2 * i for i in range(6)]
942        hist = [1, 0, 0, 0, 1]
943        self.assertEqual((buckets, hist), rdd.histogram(5))
944
945        # invalid RDDs
946        rdd = self.sc.parallelize([1, float('inf')])
947        self.assertRaises(ValueError, lambda: rdd.histogram(2))
948        rdd = self.sc.parallelize([float('nan')])
949        self.assertRaises(ValueError, lambda: rdd.histogram(2))
950
951        # string
952        rdd = self.sc.parallelize(["ab", "ac", "b", "bd", "ef"], 2)
953        self.assertEqual([2, 2], rdd.histogram(["a", "b", "c"])[1])
954        self.assertEqual((["ab", "ef"], [5]), rdd.histogram(1))
955        self.assertRaises(TypeError, lambda: rdd.histogram(2))
956
957    def test_repartitionAndSortWithinPartitions(self):
958        rdd = self.sc.parallelize([(0, 5), (3, 8), (2, 6), (0, 8), (3, 8), (1, 3)], 2)
959
960        repartitioned = rdd.repartitionAndSortWithinPartitions(2, lambda key: key % 2)
961        partitions = repartitioned.glom().collect()
962        self.assertEqual(partitions[0], [(0, 5), (0, 8), (2, 6)])
963        self.assertEqual(partitions[1], [(1, 3), (3, 8), (3, 8)])
964
965    def test_repartition_no_skewed(self):
966        num_partitions = 20
967        a = self.sc.parallelize(range(int(1000)), 2)
968        l = a.repartition(num_partitions).glom().map(len).collect()
969        zeros = len([x for x in l if x == 0])
970        self.assertTrue(zeros == 0)
971        l = a.coalesce(num_partitions, True).glom().map(len).collect()
972        zeros = len([x for x in l if x == 0])
973        self.assertTrue(zeros == 0)
974
975    def test_repartition_on_textfile(self):
976        path = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt")
977        rdd = self.sc.textFile(path)
978        result = rdd.repartition(1).collect()
979        self.assertEqual(u"Hello World!", result[0])
980
981    def test_distinct(self):
982        rdd = self.sc.parallelize((1, 2, 3)*10, 10)
983        self.assertEqual(rdd.getNumPartitions(), 10)
984        self.assertEqual(rdd.distinct().count(), 3)
985        result = rdd.distinct(5)
986        self.assertEqual(result.getNumPartitions(), 5)
987        self.assertEqual(result.count(), 3)
988
989    def test_external_group_by_key(self):
990        self.sc._conf.set("spark.python.worker.memory", "1m")
991        N = 200001
992        kv = self.sc.parallelize(xrange(N)).map(lambda x: (x % 3, x))
993        gkv = kv.groupByKey().cache()
994        self.assertEqual(3, gkv.count())
995        filtered = gkv.filter(lambda kv: kv[0] == 1)
996        self.assertEqual(1, filtered.count())
997        self.assertEqual([(1, N // 3)], filtered.mapValues(len).collect())
998        self.assertEqual([(N // 3, N // 3)],
999                         filtered.values().map(lambda x: (len(x), len(list(x)))).collect())
1000        result = filtered.collect()[0][1]
1001        self.assertEqual(N // 3, len(result))
1002        self.assertTrue(isinstance(result.data, shuffle.ExternalListOfList))
1003
1004    def test_sort_on_empty_rdd(self):
1005        self.assertEqual([], self.sc.parallelize(zip([], [])).sortByKey().collect())
1006
1007    def test_sample(self):
1008        rdd = self.sc.parallelize(range(0, 100), 4)
1009        wo = rdd.sample(False, 0.1, 2).collect()
1010        wo_dup = rdd.sample(False, 0.1, 2).collect()
1011        self.assertSetEqual(set(wo), set(wo_dup))
1012        wr = rdd.sample(True, 0.2, 5).collect()
1013        wr_dup = rdd.sample(True, 0.2, 5).collect()
1014        self.assertSetEqual(set(wr), set(wr_dup))
1015        wo_s10 = rdd.sample(False, 0.3, 10).collect()
1016        wo_s20 = rdd.sample(False, 0.3, 20).collect()
1017        self.assertNotEqual(set(wo_s10), set(wo_s20))
1018        wr_s11 = rdd.sample(True, 0.4, 11).collect()
1019        wr_s21 = rdd.sample(True, 0.4, 21).collect()
1020        self.assertNotEqual(set(wr_s11), set(wr_s21))
1021
1022    def test_null_in_rdd(self):
1023        jrdd = self.sc._jvm.PythonUtils.generateRDDWithNull(self.sc._jsc)
1024        rdd = RDD(jrdd, self.sc, UTF8Deserializer())
1025        self.assertEqual([u"a", None, u"b"], rdd.collect())
1026        rdd = RDD(jrdd, self.sc, NoOpSerializer())
1027        self.assertEqual([b"a", None, b"b"], rdd.collect())
1028
1029    def test_multiple_python_java_RDD_conversions(self):
1030        # Regression test for SPARK-5361
1031        data = [
1032            (u'1', {u'director': u'David Lean'}),
1033            (u'2', {u'director': u'Andrew Dominik'})
1034        ]
1035        data_rdd = self.sc.parallelize(data)
1036        data_java_rdd = data_rdd._to_java_object_rdd()
1037        data_python_rdd = self.sc._jvm.SerDeUtil.javaToPython(data_java_rdd)
1038        converted_rdd = RDD(data_python_rdd, self.sc)
1039        self.assertEqual(2, converted_rdd.count())
1040
1041        # conversion between python and java RDD threw exceptions
1042        data_java_rdd = converted_rdd._to_java_object_rdd()
1043        data_python_rdd = self.sc._jvm.SerDeUtil.javaToPython(data_java_rdd)
1044        converted_rdd = RDD(data_python_rdd, self.sc)
1045        self.assertEqual(2, converted_rdd.count())
1046
1047    def test_narrow_dependency_in_join(self):
1048        rdd = self.sc.parallelize(range(10)).map(lambda x: (x, x))
1049        parted = rdd.partitionBy(2)
1050        self.assertEqual(2, parted.union(parted).getNumPartitions())
1051        self.assertEqual(rdd.getNumPartitions() + 2, parted.union(rdd).getNumPartitions())
1052        self.assertEqual(rdd.getNumPartitions() + 2, rdd.union(parted).getNumPartitions())
1053
1054        tracker = self.sc.statusTracker()
1055
1056        self.sc.setJobGroup("test1", "test", True)
1057        d = sorted(parted.join(parted).collect())
1058        self.assertEqual(10, len(d))
1059        self.assertEqual((0, (0, 0)), d[0])
1060        jobId = tracker.getJobIdsForGroup("test1")[0]
1061        self.assertEqual(2, len(tracker.getJobInfo(jobId).stageIds))
1062
1063        self.sc.setJobGroup("test2", "test", True)
1064        d = sorted(parted.join(rdd).collect())
1065        self.assertEqual(10, len(d))
1066        self.assertEqual((0, (0, 0)), d[0])
1067        jobId = tracker.getJobIdsForGroup("test2")[0]
1068        self.assertEqual(3, len(tracker.getJobInfo(jobId).stageIds))
1069
1070        self.sc.setJobGroup("test3", "test", True)
1071        d = sorted(parted.cogroup(parted).collect())
1072        self.assertEqual(10, len(d))
1073        self.assertEqual([[0], [0]], list(map(list, d[0][1])))
1074        jobId = tracker.getJobIdsForGroup("test3")[0]
1075        self.assertEqual(2, len(tracker.getJobInfo(jobId).stageIds))
1076
1077        self.sc.setJobGroup("test4", "test", True)
1078        d = sorted(parted.cogroup(rdd).collect())
1079        self.assertEqual(10, len(d))
1080        self.assertEqual([[0], [0]], list(map(list, d[0][1])))
1081        jobId = tracker.getJobIdsForGroup("test4")[0]
1082        self.assertEqual(3, len(tracker.getJobInfo(jobId).stageIds))
1083
1084    # Regression test for SPARK-6294
1085    def test_take_on_jrdd(self):
1086        rdd = self.sc.parallelize(xrange(1 << 20)).map(lambda x: str(x))
1087        rdd._jrdd.first()
1088
1089    def test_sortByKey_uses_all_partitions_not_only_first_and_last(self):
1090        # Regression test for SPARK-5969
1091        seq = [(i * 59 % 101, i) for i in range(101)]  # unsorted sequence
1092        rdd = self.sc.parallelize(seq)
1093        for ascending in [True, False]:
1094            sort = rdd.sortByKey(ascending=ascending, numPartitions=5)
1095            self.assertEqual(sort.collect(), sorted(seq, reverse=not ascending))
1096            sizes = sort.glom().map(len).collect()
1097            for size in sizes:
1098                self.assertGreater(size, 0)
1099
1100    def test_pipe_functions(self):
1101        data = ['1', '2', '3']
1102        rdd = self.sc.parallelize(data)
1103        with QuietTest(self.sc):
1104            self.assertEqual([], rdd.pipe('cc').collect())
1105            self.assertRaises(Py4JJavaError, rdd.pipe('cc', checkCode=True).collect)
1106        result = rdd.pipe('cat').collect()
1107        result.sort()
1108        for x, y in zip(data, result):
1109            self.assertEqual(x, y)
1110        self.assertRaises(Py4JJavaError, rdd.pipe('grep 4', checkCode=True).collect)
1111        self.assertEqual([], rdd.pipe('grep 4').collect())
1112
1113
1114class ProfilerTests(PySparkTestCase):
1115
1116    def setUp(self):
1117        self._old_sys_path = list(sys.path)
1118        class_name = self.__class__.__name__
1119        conf = SparkConf().set("spark.python.profile", "true")
1120        self.sc = SparkContext('local[4]', class_name, conf=conf)
1121
1122    def test_profiler(self):
1123        self.do_computation()
1124
1125        profilers = self.sc.profiler_collector.profilers
1126        self.assertEqual(1, len(profilers))
1127        id, profiler, _ = profilers[0]
1128        stats = profiler.stats()
1129        self.assertTrue(stats is not None)
1130        width, stat_list = stats.get_print_list([])
1131        func_names = [func_name for fname, n, func_name in stat_list]
1132        self.assertTrue("heavy_foo" in func_names)
1133
1134        old_stdout = sys.stdout
1135        sys.stdout = io = StringIO()
1136        self.sc.show_profiles()
1137        self.assertTrue("heavy_foo" in io.getvalue())
1138        sys.stdout = old_stdout
1139
1140        d = tempfile.gettempdir()
1141        self.sc.dump_profiles(d)
1142        self.assertTrue("rdd_%d.pstats" % id in os.listdir(d))
1143
1144    def test_custom_profiler(self):
1145        class TestCustomProfiler(BasicProfiler):
1146            def show(self, id):
1147                self.result = "Custom formatting"
1148
1149        self.sc.profiler_collector.profiler_cls = TestCustomProfiler
1150
1151        self.do_computation()
1152
1153        profilers = self.sc.profiler_collector.profilers
1154        self.assertEqual(1, len(profilers))
1155        _, profiler, _ = profilers[0]
1156        self.assertTrue(isinstance(profiler, TestCustomProfiler))
1157
1158        self.sc.show_profiles()
1159        self.assertEqual("Custom formatting", profiler.result)
1160
1161    def do_computation(self):
1162        def heavy_foo(x):
1163            for i in range(1 << 18):
1164                x = 1
1165
1166        rdd = self.sc.parallelize(range(100))
1167        rdd.foreach(heavy_foo)
1168
1169
1170class InputFormatTests(ReusedPySparkTestCase):
1171
1172    @classmethod
1173    def setUpClass(cls):
1174        ReusedPySparkTestCase.setUpClass()
1175        cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
1176        os.unlink(cls.tempdir.name)
1177        cls.sc._jvm.WriteInputFormatTestDataGenerator.generateData(cls.tempdir.name, cls.sc._jsc)
1178
1179    @classmethod
1180    def tearDownClass(cls):
1181        ReusedPySparkTestCase.tearDownClass()
1182        shutil.rmtree(cls.tempdir.name)
1183
1184    @unittest.skipIf(sys.version >= "3", "serialize array of byte")
1185    def test_sequencefiles(self):
1186        basepath = self.tempdir.name
1187        ints = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfint/",
1188                                           "org.apache.hadoop.io.IntWritable",
1189                                           "org.apache.hadoop.io.Text").collect())
1190        ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')]
1191        self.assertEqual(ints, ei)
1192
1193        doubles = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfdouble/",
1194                                              "org.apache.hadoop.io.DoubleWritable",
1195                                              "org.apache.hadoop.io.Text").collect())
1196        ed = [(1.0, u'aa'), (1.0, u'aa'), (2.0, u'aa'), (2.0, u'bb'), (2.0, u'bb'), (3.0, u'cc')]
1197        self.assertEqual(doubles, ed)
1198
1199        bytes = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfbytes/",
1200                                            "org.apache.hadoop.io.IntWritable",
1201                                            "org.apache.hadoop.io.BytesWritable").collect())
1202        ebs = [(1, bytearray('aa', 'utf-8')),
1203               (1, bytearray('aa', 'utf-8')),
1204               (2, bytearray('aa', 'utf-8')),
1205               (2, bytearray('bb', 'utf-8')),
1206               (2, bytearray('bb', 'utf-8')),
1207               (3, bytearray('cc', 'utf-8'))]
1208        self.assertEqual(bytes, ebs)
1209
1210        text = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sftext/",
1211                                           "org.apache.hadoop.io.Text",
1212                                           "org.apache.hadoop.io.Text").collect())
1213        et = [(u'1', u'aa'),
1214              (u'1', u'aa'),
1215              (u'2', u'aa'),
1216              (u'2', u'bb'),
1217              (u'2', u'bb'),
1218              (u'3', u'cc')]
1219        self.assertEqual(text, et)
1220
1221        bools = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfbool/",
1222                                            "org.apache.hadoop.io.IntWritable",
1223                                            "org.apache.hadoop.io.BooleanWritable").collect())
1224        eb = [(1, False), (1, True), (2, False), (2, False), (2, True), (3, True)]
1225        self.assertEqual(bools, eb)
1226
1227        nulls = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfnull/",
1228                                            "org.apache.hadoop.io.IntWritable",
1229                                            "org.apache.hadoop.io.BooleanWritable").collect())
1230        en = [(1, None), (1, None), (2, None), (2, None), (2, None), (3, None)]
1231        self.assertEqual(nulls, en)
1232
1233        maps = self.sc.sequenceFile(basepath + "/sftestdata/sfmap/",
1234                                    "org.apache.hadoop.io.IntWritable",
1235                                    "org.apache.hadoop.io.MapWritable").collect()
1236        em = [(1, {}),
1237              (1, {3.0: u'bb'}),
1238              (2, {1.0: u'aa'}),
1239              (2, {1.0: u'cc'}),
1240              (3, {2.0: u'dd'})]
1241        for v in maps:
1242            self.assertTrue(v in em)
1243
1244        # arrays get pickled to tuples by default
1245        tuples = sorted(self.sc.sequenceFile(
1246            basepath + "/sftestdata/sfarray/",
1247            "org.apache.hadoop.io.IntWritable",
1248            "org.apache.spark.api.python.DoubleArrayWritable").collect())
1249        et = [(1, ()),
1250              (2, (3.0, 4.0, 5.0)),
1251              (3, (4.0, 5.0, 6.0))]
1252        self.assertEqual(tuples, et)
1253
1254        # with custom converters, primitive arrays can stay as arrays
1255        arrays = sorted(self.sc.sequenceFile(
1256            basepath + "/sftestdata/sfarray/",
1257            "org.apache.hadoop.io.IntWritable",
1258            "org.apache.spark.api.python.DoubleArrayWritable",
1259            valueConverter="org.apache.spark.api.python.WritableToDoubleArrayConverter").collect())
1260        ea = [(1, array('d')),
1261              (2, array('d', [3.0, 4.0, 5.0])),
1262              (3, array('d', [4.0, 5.0, 6.0]))]
1263        self.assertEqual(arrays, ea)
1264
1265        clazz = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfclass/",
1266                                            "org.apache.hadoop.io.Text",
1267                                            "org.apache.spark.api.python.TestWritable").collect())
1268        cname = u'org.apache.spark.api.python.TestWritable'
1269        ec = [(u'1', {u'__class__': cname, u'double': 1.0, u'int': 1, u'str': u'test1'}),
1270              (u'2', {u'__class__': cname, u'double': 2.3, u'int': 2, u'str': u'test2'}),
1271              (u'3', {u'__class__': cname, u'double': 3.1, u'int': 3, u'str': u'test3'}),
1272              (u'4', {u'__class__': cname, u'double': 4.2, u'int': 4, u'str': u'test4'}),
1273              (u'5', {u'__class__': cname, u'double': 5.5, u'int': 5, u'str': u'test56'})]
1274        self.assertEqual(clazz, ec)
1275
1276        unbatched_clazz = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfclass/",
1277                                                      "org.apache.hadoop.io.Text",
1278                                                      "org.apache.spark.api.python.TestWritable",
1279                                                      ).collect())
1280        self.assertEqual(unbatched_clazz, ec)
1281
1282    def test_oldhadoop(self):
1283        basepath = self.tempdir.name
1284        ints = sorted(self.sc.hadoopFile(basepath + "/sftestdata/sfint/",
1285                                         "org.apache.hadoop.mapred.SequenceFileInputFormat",
1286                                         "org.apache.hadoop.io.IntWritable",
1287                                         "org.apache.hadoop.io.Text").collect())
1288        ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')]
1289        self.assertEqual(ints, ei)
1290
1291        hellopath = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt")
1292        oldconf = {"mapred.input.dir": hellopath}
1293        hello = self.sc.hadoopRDD("org.apache.hadoop.mapred.TextInputFormat",
1294                                  "org.apache.hadoop.io.LongWritable",
1295                                  "org.apache.hadoop.io.Text",
1296                                  conf=oldconf).collect()
1297        result = [(0, u'Hello World!')]
1298        self.assertEqual(hello, result)
1299
1300    def test_newhadoop(self):
1301        basepath = self.tempdir.name
1302        ints = sorted(self.sc.newAPIHadoopFile(
1303            basepath + "/sftestdata/sfint/",
1304            "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat",
1305            "org.apache.hadoop.io.IntWritable",
1306            "org.apache.hadoop.io.Text").collect())
1307        ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')]
1308        self.assertEqual(ints, ei)
1309
1310        hellopath = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt")
1311        newconf = {"mapred.input.dir": hellopath}
1312        hello = self.sc.newAPIHadoopRDD("org.apache.hadoop.mapreduce.lib.input.TextInputFormat",
1313                                        "org.apache.hadoop.io.LongWritable",
1314                                        "org.apache.hadoop.io.Text",
1315                                        conf=newconf).collect()
1316        result = [(0, u'Hello World!')]
1317        self.assertEqual(hello, result)
1318
1319    def test_newolderror(self):
1320        basepath = self.tempdir.name
1321        self.assertRaises(Exception, lambda: self.sc.hadoopFile(
1322            basepath + "/sftestdata/sfint/",
1323            "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat",
1324            "org.apache.hadoop.io.IntWritable",
1325            "org.apache.hadoop.io.Text"))
1326
1327        self.assertRaises(Exception, lambda: self.sc.newAPIHadoopFile(
1328            basepath + "/sftestdata/sfint/",
1329            "org.apache.hadoop.mapred.SequenceFileInputFormat",
1330            "org.apache.hadoop.io.IntWritable",
1331            "org.apache.hadoop.io.Text"))
1332
1333    def test_bad_inputs(self):
1334        basepath = self.tempdir.name
1335        self.assertRaises(Exception, lambda: self.sc.sequenceFile(
1336            basepath + "/sftestdata/sfint/",
1337            "org.apache.hadoop.io.NotValidWritable",
1338            "org.apache.hadoop.io.Text"))
1339        self.assertRaises(Exception, lambda: self.sc.hadoopFile(
1340            basepath + "/sftestdata/sfint/",
1341            "org.apache.hadoop.mapred.NotValidInputFormat",
1342            "org.apache.hadoop.io.IntWritable",
1343            "org.apache.hadoop.io.Text"))
1344        self.assertRaises(Exception, lambda: self.sc.newAPIHadoopFile(
1345            basepath + "/sftestdata/sfint/",
1346            "org.apache.hadoop.mapreduce.lib.input.NotValidInputFormat",
1347            "org.apache.hadoop.io.IntWritable",
1348            "org.apache.hadoop.io.Text"))
1349
1350    def test_converters(self):
1351        # use of custom converters
1352        basepath = self.tempdir.name
1353        maps = sorted(self.sc.sequenceFile(
1354            basepath + "/sftestdata/sfmap/",
1355            "org.apache.hadoop.io.IntWritable",
1356            "org.apache.hadoop.io.MapWritable",
1357            keyConverter="org.apache.spark.api.python.TestInputKeyConverter",
1358            valueConverter="org.apache.spark.api.python.TestInputValueConverter").collect())
1359        em = [(u'\x01', []),
1360              (u'\x01', [3.0]),
1361              (u'\x02', [1.0]),
1362              (u'\x02', [1.0]),
1363              (u'\x03', [2.0])]
1364        self.assertEqual(maps, em)
1365
1366    def test_binary_files(self):
1367        path = os.path.join(self.tempdir.name, "binaryfiles")
1368        os.mkdir(path)
1369        data = b"short binary data"
1370        with open(os.path.join(path, "part-0000"), 'wb') as f:
1371            f.write(data)
1372        [(p, d)] = self.sc.binaryFiles(path).collect()
1373        self.assertTrue(p.endswith("part-0000"))
1374        self.assertEqual(d, data)
1375
1376    def test_binary_records(self):
1377        path = os.path.join(self.tempdir.name, "binaryrecords")
1378        os.mkdir(path)
1379        with open(os.path.join(path, "part-0000"), 'w') as f:
1380            for i in range(100):
1381                f.write('%04d' % i)
1382        result = self.sc.binaryRecords(path, 4).map(int).collect()
1383        self.assertEqual(list(range(100)), result)
1384
1385
1386class OutputFormatTests(ReusedPySparkTestCase):
1387
1388    def setUp(self):
1389        self.tempdir = tempfile.NamedTemporaryFile(delete=False)
1390        os.unlink(self.tempdir.name)
1391
1392    def tearDown(self):
1393        shutil.rmtree(self.tempdir.name, ignore_errors=True)
1394
1395    @unittest.skipIf(sys.version >= "3", "serialize array of byte")
1396    def test_sequencefiles(self):
1397        basepath = self.tempdir.name
1398        ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')]
1399        self.sc.parallelize(ei).saveAsSequenceFile(basepath + "/sfint/")
1400        ints = sorted(self.sc.sequenceFile(basepath + "/sfint/").collect())
1401        self.assertEqual(ints, ei)
1402
1403        ed = [(1.0, u'aa'), (1.0, u'aa'), (2.0, u'aa'), (2.0, u'bb'), (2.0, u'bb'), (3.0, u'cc')]
1404        self.sc.parallelize(ed).saveAsSequenceFile(basepath + "/sfdouble/")
1405        doubles = sorted(self.sc.sequenceFile(basepath + "/sfdouble/").collect())
1406        self.assertEqual(doubles, ed)
1407
1408        ebs = [(1, bytearray(b'\x00\x07spam\x08')), (2, bytearray(b'\x00\x07spam\x08'))]
1409        self.sc.parallelize(ebs).saveAsSequenceFile(basepath + "/sfbytes/")
1410        bytes = sorted(self.sc.sequenceFile(basepath + "/sfbytes/").collect())
1411        self.assertEqual(bytes, ebs)
1412
1413        et = [(u'1', u'aa'),
1414              (u'2', u'bb'),
1415              (u'3', u'cc')]
1416        self.sc.parallelize(et).saveAsSequenceFile(basepath + "/sftext/")
1417        text = sorted(self.sc.sequenceFile(basepath + "/sftext/").collect())
1418        self.assertEqual(text, et)
1419
1420        eb = [(1, False), (1, True), (2, False), (2, False), (2, True), (3, True)]
1421        self.sc.parallelize(eb).saveAsSequenceFile(basepath + "/sfbool/")
1422        bools = sorted(self.sc.sequenceFile(basepath + "/sfbool/").collect())
1423        self.assertEqual(bools, eb)
1424
1425        en = [(1, None), (1, None), (2, None), (2, None), (2, None), (3, None)]
1426        self.sc.parallelize(en).saveAsSequenceFile(basepath + "/sfnull/")
1427        nulls = sorted(self.sc.sequenceFile(basepath + "/sfnull/").collect())
1428        self.assertEqual(nulls, en)
1429
1430        em = [(1, {}),
1431              (1, {3.0: u'bb'}),
1432              (2, {1.0: u'aa'}),
1433              (2, {1.0: u'cc'}),
1434              (3, {2.0: u'dd'})]
1435        self.sc.parallelize(em).saveAsSequenceFile(basepath + "/sfmap/")
1436        maps = self.sc.sequenceFile(basepath + "/sfmap/").collect()
1437        for v in maps:
1438            self.assertTrue(v, em)
1439
1440    def test_oldhadoop(self):
1441        basepath = self.tempdir.name
1442        dict_data = [(1, {}),
1443                     (1, {"row1": 1.0}),
1444                     (2, {"row2": 2.0})]
1445        self.sc.parallelize(dict_data).saveAsHadoopFile(
1446            basepath + "/oldhadoop/",
1447            "org.apache.hadoop.mapred.SequenceFileOutputFormat",
1448            "org.apache.hadoop.io.IntWritable",
1449            "org.apache.hadoop.io.MapWritable")
1450        result = self.sc.hadoopFile(
1451            basepath + "/oldhadoop/",
1452            "org.apache.hadoop.mapred.SequenceFileInputFormat",
1453            "org.apache.hadoop.io.IntWritable",
1454            "org.apache.hadoop.io.MapWritable").collect()
1455        for v in result:
1456            self.assertTrue(v, dict_data)
1457
1458        conf = {
1459            "mapred.output.format.class": "org.apache.hadoop.mapred.SequenceFileOutputFormat",
1460            "mapred.output.key.class": "org.apache.hadoop.io.IntWritable",
1461            "mapred.output.value.class": "org.apache.hadoop.io.MapWritable",
1462            "mapred.output.dir": basepath + "/olddataset/"
1463        }
1464        self.sc.parallelize(dict_data).saveAsHadoopDataset(conf)
1465        input_conf = {"mapred.input.dir": basepath + "/olddataset/"}
1466        result = self.sc.hadoopRDD(
1467            "org.apache.hadoop.mapred.SequenceFileInputFormat",
1468            "org.apache.hadoop.io.IntWritable",
1469            "org.apache.hadoop.io.MapWritable",
1470            conf=input_conf).collect()
1471        for v in result:
1472            self.assertTrue(v, dict_data)
1473
1474    def test_newhadoop(self):
1475        basepath = self.tempdir.name
1476        data = [(1, ""),
1477                (1, "a"),
1478                (2, "bcdf")]
1479        self.sc.parallelize(data).saveAsNewAPIHadoopFile(
1480            basepath + "/newhadoop/",
1481            "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat",
1482            "org.apache.hadoop.io.IntWritable",
1483            "org.apache.hadoop.io.Text")
1484        result = sorted(self.sc.newAPIHadoopFile(
1485            basepath + "/newhadoop/",
1486            "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat",
1487            "org.apache.hadoop.io.IntWritable",
1488            "org.apache.hadoop.io.Text").collect())
1489        self.assertEqual(result, data)
1490
1491        conf = {
1492            "mapreduce.outputformat.class":
1493                "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat",
1494            "mapred.output.key.class": "org.apache.hadoop.io.IntWritable",
1495            "mapred.output.value.class": "org.apache.hadoop.io.Text",
1496            "mapred.output.dir": basepath + "/newdataset/"
1497        }
1498        self.sc.parallelize(data).saveAsNewAPIHadoopDataset(conf)
1499        input_conf = {"mapred.input.dir": basepath + "/newdataset/"}
1500        new_dataset = sorted(self.sc.newAPIHadoopRDD(
1501            "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat",
1502            "org.apache.hadoop.io.IntWritable",
1503            "org.apache.hadoop.io.Text",
1504            conf=input_conf).collect())
1505        self.assertEqual(new_dataset, data)
1506
1507    @unittest.skipIf(sys.version >= "3", "serialize of array")
1508    def test_newhadoop_with_array(self):
1509        basepath = self.tempdir.name
1510        # use custom ArrayWritable types and converters to handle arrays
1511        array_data = [(1, array('d')),
1512                      (1, array('d', [1.0, 2.0, 3.0])),
1513                      (2, array('d', [3.0, 4.0, 5.0]))]
1514        self.sc.parallelize(array_data).saveAsNewAPIHadoopFile(
1515            basepath + "/newhadoop/",
1516            "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat",
1517            "org.apache.hadoop.io.IntWritable",
1518            "org.apache.spark.api.python.DoubleArrayWritable",
1519            valueConverter="org.apache.spark.api.python.DoubleArrayToWritableConverter")
1520        result = sorted(self.sc.newAPIHadoopFile(
1521            basepath + "/newhadoop/",
1522            "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat",
1523            "org.apache.hadoop.io.IntWritable",
1524            "org.apache.spark.api.python.DoubleArrayWritable",
1525            valueConverter="org.apache.spark.api.python.WritableToDoubleArrayConverter").collect())
1526        self.assertEqual(result, array_data)
1527
1528        conf = {
1529            "mapreduce.outputformat.class":
1530                "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat",
1531            "mapred.output.key.class": "org.apache.hadoop.io.IntWritable",
1532            "mapred.output.value.class": "org.apache.spark.api.python.DoubleArrayWritable",
1533            "mapred.output.dir": basepath + "/newdataset/"
1534        }
1535        self.sc.parallelize(array_data).saveAsNewAPIHadoopDataset(
1536            conf,
1537            valueConverter="org.apache.spark.api.python.DoubleArrayToWritableConverter")
1538        input_conf = {"mapred.input.dir": basepath + "/newdataset/"}
1539        new_dataset = sorted(self.sc.newAPIHadoopRDD(
1540            "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat",
1541            "org.apache.hadoop.io.IntWritable",
1542            "org.apache.spark.api.python.DoubleArrayWritable",
1543            valueConverter="org.apache.spark.api.python.WritableToDoubleArrayConverter",
1544            conf=input_conf).collect())
1545        self.assertEqual(new_dataset, array_data)
1546
1547    def test_newolderror(self):
1548        basepath = self.tempdir.name
1549        rdd = self.sc.parallelize(range(1, 4)).map(lambda x: (x, "a" * x))
1550        self.assertRaises(Exception, lambda: rdd.saveAsHadoopFile(
1551            basepath + "/newolderror/saveAsHadoopFile/",
1552            "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat"))
1553        self.assertRaises(Exception, lambda: rdd.saveAsNewAPIHadoopFile(
1554            basepath + "/newolderror/saveAsNewAPIHadoopFile/",
1555            "org.apache.hadoop.mapred.SequenceFileOutputFormat"))
1556
1557    def test_bad_inputs(self):
1558        basepath = self.tempdir.name
1559        rdd = self.sc.parallelize(range(1, 4)).map(lambda x: (x, "a" * x))
1560        self.assertRaises(Exception, lambda: rdd.saveAsHadoopFile(
1561            basepath + "/badinputs/saveAsHadoopFile/",
1562            "org.apache.hadoop.mapred.NotValidOutputFormat"))
1563        self.assertRaises(Exception, lambda: rdd.saveAsNewAPIHadoopFile(
1564            basepath + "/badinputs/saveAsNewAPIHadoopFile/",
1565            "org.apache.hadoop.mapreduce.lib.output.NotValidOutputFormat"))
1566
1567    def test_converters(self):
1568        # use of custom converters
1569        basepath = self.tempdir.name
1570        data = [(1, {3.0: u'bb'}),
1571                (2, {1.0: u'aa'}),
1572                (3, {2.0: u'dd'})]
1573        self.sc.parallelize(data).saveAsNewAPIHadoopFile(
1574            basepath + "/converters/",
1575            "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat",
1576            keyConverter="org.apache.spark.api.python.TestOutputKeyConverter",
1577            valueConverter="org.apache.spark.api.python.TestOutputValueConverter")
1578        converted = sorted(self.sc.sequenceFile(basepath + "/converters/").collect())
1579        expected = [(u'1', 3.0),
1580                    (u'2', 1.0),
1581                    (u'3', 2.0)]
1582        self.assertEqual(converted, expected)
1583
1584    def test_reserialization(self):
1585        basepath = self.tempdir.name
1586        x = range(1, 5)
1587        y = range(1001, 1005)
1588        data = list(zip(x, y))
1589        rdd = self.sc.parallelize(x).zip(self.sc.parallelize(y))
1590        rdd.saveAsSequenceFile(basepath + "/reserialize/sequence")
1591        result1 = sorted(self.sc.sequenceFile(basepath + "/reserialize/sequence").collect())
1592        self.assertEqual(result1, data)
1593
1594        rdd.saveAsHadoopFile(
1595            basepath + "/reserialize/hadoop",
1596            "org.apache.hadoop.mapred.SequenceFileOutputFormat")
1597        result2 = sorted(self.sc.sequenceFile(basepath + "/reserialize/hadoop").collect())
1598        self.assertEqual(result2, data)
1599
1600        rdd.saveAsNewAPIHadoopFile(
1601            basepath + "/reserialize/newhadoop",
1602            "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat")
1603        result3 = sorted(self.sc.sequenceFile(basepath + "/reserialize/newhadoop").collect())
1604        self.assertEqual(result3, data)
1605
1606        conf4 = {
1607            "mapred.output.format.class": "org.apache.hadoop.mapred.SequenceFileOutputFormat",
1608            "mapred.output.key.class": "org.apache.hadoop.io.IntWritable",
1609            "mapred.output.value.class": "org.apache.hadoop.io.IntWritable",
1610            "mapred.output.dir": basepath + "/reserialize/dataset"}
1611        rdd.saveAsHadoopDataset(conf4)
1612        result4 = sorted(self.sc.sequenceFile(basepath + "/reserialize/dataset").collect())
1613        self.assertEqual(result4, data)
1614
1615        conf5 = {"mapreduce.outputformat.class":
1616                 "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat",
1617                 "mapred.output.key.class": "org.apache.hadoop.io.IntWritable",
1618                 "mapred.output.value.class": "org.apache.hadoop.io.IntWritable",
1619                 "mapred.output.dir": basepath + "/reserialize/newdataset"}
1620        rdd.saveAsNewAPIHadoopDataset(conf5)
1621        result5 = sorted(self.sc.sequenceFile(basepath + "/reserialize/newdataset").collect())
1622        self.assertEqual(result5, data)
1623
1624    def test_malformed_RDD(self):
1625        basepath = self.tempdir.name
1626        # non-batch-serialized RDD[[(K, V)]] should be rejected
1627        data = [[(1, "a")], [(2, "aa")], [(3, "aaa")]]
1628        rdd = self.sc.parallelize(data, len(data))
1629        self.assertRaises(Exception, lambda: rdd.saveAsSequenceFile(
1630            basepath + "/malformed/sequence"))
1631
1632
1633class DaemonTests(unittest.TestCase):
1634    def connect(self, port):
1635        from socket import socket, AF_INET, SOCK_STREAM
1636        sock = socket(AF_INET, SOCK_STREAM)
1637        sock.connect(('127.0.0.1', port))
1638        # send a split index of -1 to shutdown the worker
1639        sock.send(b"\xFF\xFF\xFF\xFF")
1640        sock.close()
1641        return True
1642
1643    def do_termination_test(self, terminator):
1644        from subprocess import Popen, PIPE
1645        from errno import ECONNREFUSED
1646
1647        # start daemon
1648        daemon_path = os.path.join(os.path.dirname(__file__), "daemon.py")
1649        python_exec = sys.executable or os.environ.get("PYSPARK_PYTHON")
1650        daemon = Popen([python_exec, daemon_path], stdin=PIPE, stdout=PIPE)
1651
1652        # read the port number
1653        port = read_int(daemon.stdout)
1654
1655        # daemon should accept connections
1656        self.assertTrue(self.connect(port))
1657
1658        # request shutdown
1659        terminator(daemon)
1660        time.sleep(1)
1661
1662        # daemon should no longer accept connections
1663        try:
1664            self.connect(port)
1665        except EnvironmentError as exception:
1666            self.assertEqual(exception.errno, ECONNREFUSED)
1667        else:
1668            self.fail("Expected EnvironmentError to be raised")
1669
1670    def test_termination_stdin(self):
1671        """Ensure that daemon and workers terminate when stdin is closed."""
1672        self.do_termination_test(lambda daemon: daemon.stdin.close())
1673
1674    def test_termination_sigterm(self):
1675        """Ensure that daemon and workers terminate on SIGTERM."""
1676        from signal import SIGTERM
1677        self.do_termination_test(lambda daemon: os.kill(daemon.pid, SIGTERM))
1678
1679
1680class WorkerTests(ReusedPySparkTestCase):
1681    def test_cancel_task(self):
1682        temp = tempfile.NamedTemporaryFile(delete=True)
1683        temp.close()
1684        path = temp.name
1685
1686        def sleep(x):
1687            import os
1688            import time
1689            with open(path, 'w') as f:
1690                f.write("%d %d" % (os.getppid(), os.getpid()))
1691            time.sleep(100)
1692
1693        # start job in background thread
1694        def run():
1695            try:
1696                self.sc.parallelize(range(1), 1).foreach(sleep)
1697            except Exception:
1698                pass
1699        import threading
1700        t = threading.Thread(target=run)
1701        t.daemon = True
1702        t.start()
1703
1704        daemon_pid, worker_pid = 0, 0
1705        while True:
1706            if os.path.exists(path):
1707                with open(path) as f:
1708                    data = f.read().split(' ')
1709                daemon_pid, worker_pid = map(int, data)
1710                break
1711            time.sleep(0.1)
1712
1713        # cancel jobs
1714        self.sc.cancelAllJobs()
1715        t.join()
1716
1717        for i in range(50):
1718            try:
1719                os.kill(worker_pid, 0)
1720                time.sleep(0.1)
1721            except OSError:
1722                break  # worker was killed
1723        else:
1724            self.fail("worker has not been killed after 5 seconds")
1725
1726        try:
1727            os.kill(daemon_pid, 0)
1728        except OSError:
1729            self.fail("daemon had been killed")
1730
1731        # run a normal job
1732        rdd = self.sc.parallelize(xrange(100), 1)
1733        self.assertEqual(100, rdd.map(str).count())
1734
1735    def test_after_exception(self):
1736        def raise_exception(_):
1737            raise Exception()
1738        rdd = self.sc.parallelize(xrange(100), 1)
1739        with QuietTest(self.sc):
1740            self.assertRaises(Exception, lambda: rdd.foreach(raise_exception))
1741        self.assertEqual(100, rdd.map(str).count())
1742
1743    def test_after_jvm_exception(self):
1744        tempFile = tempfile.NamedTemporaryFile(delete=False)
1745        tempFile.write(b"Hello World!")
1746        tempFile.close()
1747        data = self.sc.textFile(tempFile.name, 1)
1748        filtered_data = data.filter(lambda x: True)
1749        self.assertEqual(1, filtered_data.count())
1750        os.unlink(tempFile.name)
1751        with QuietTest(self.sc):
1752            self.assertRaises(Exception, lambda: filtered_data.count())
1753
1754        rdd = self.sc.parallelize(xrange(100), 1)
1755        self.assertEqual(100, rdd.map(str).count())
1756
1757    def test_accumulator_when_reuse_worker(self):
1758        from pyspark.accumulators import INT_ACCUMULATOR_PARAM
1759        acc1 = self.sc.accumulator(0, INT_ACCUMULATOR_PARAM)
1760        self.sc.parallelize(xrange(100), 20).foreach(lambda x: acc1.add(x))
1761        self.assertEqual(sum(range(100)), acc1.value)
1762
1763        acc2 = self.sc.accumulator(0, INT_ACCUMULATOR_PARAM)
1764        self.sc.parallelize(xrange(100), 20).foreach(lambda x: acc2.add(x))
1765        self.assertEqual(sum(range(100)), acc2.value)
1766        self.assertEqual(sum(range(100)), acc1.value)
1767
1768    def test_reuse_worker_after_take(self):
1769        rdd = self.sc.parallelize(xrange(100000), 1)
1770        self.assertEqual(0, rdd.first())
1771
1772        def count():
1773            try:
1774                rdd.count()
1775            except Exception:
1776                pass
1777
1778        t = threading.Thread(target=count)
1779        t.daemon = True
1780        t.start()
1781        t.join(5)
1782        self.assertTrue(not t.isAlive())
1783        self.assertEqual(100000, rdd.count())
1784
1785    def test_with_different_versions_of_python(self):
1786        rdd = self.sc.parallelize(range(10))
1787        rdd.count()
1788        version = self.sc.pythonVer
1789        self.sc.pythonVer = "2.0"
1790        try:
1791            with QuietTest(self.sc):
1792                self.assertRaises(Py4JJavaError, lambda: rdd.count())
1793        finally:
1794            self.sc.pythonVer = version
1795
1796
1797class SparkSubmitTests(unittest.TestCase):
1798
1799    def setUp(self):
1800        self.programDir = tempfile.mkdtemp()
1801        self.sparkSubmit = os.path.join(os.environ.get("SPARK_HOME"), "bin", "spark-submit")
1802
1803    def tearDown(self):
1804        shutil.rmtree(self.programDir)
1805
1806    def createTempFile(self, name, content, dir=None):
1807        """
1808        Create a temp file with the given name and content and return its path.
1809        Strips leading spaces from content up to the first '|' in each line.
1810        """
1811        pattern = re.compile(r'^ *\|', re.MULTILINE)
1812        content = re.sub(pattern, '', content.strip())
1813        if dir is None:
1814            path = os.path.join(self.programDir, name)
1815        else:
1816            os.makedirs(os.path.join(self.programDir, dir))
1817            path = os.path.join(self.programDir, dir, name)
1818        with open(path, "w") as f:
1819            f.write(content)
1820        return path
1821
1822    def createFileInZip(self, name, content, ext=".zip", dir=None, zip_name=None):
1823        """
1824        Create a zip archive containing a file with the given content and return its path.
1825        Strips leading spaces from content up to the first '|' in each line.
1826        """
1827        pattern = re.compile(r'^ *\|', re.MULTILINE)
1828        content = re.sub(pattern, '', content.strip())
1829        if dir is None:
1830            path = os.path.join(self.programDir, name + ext)
1831        else:
1832            path = os.path.join(self.programDir, dir, zip_name + ext)
1833        zip = zipfile.ZipFile(path, 'w')
1834        zip.writestr(name, content)
1835        zip.close()
1836        return path
1837
1838    def create_spark_package(self, artifact_name):
1839        group_id, artifact_id, version = artifact_name.split(":")
1840        self.createTempFile("%s-%s.pom" % (artifact_id, version), ("""
1841            |<?xml version="1.0" encoding="UTF-8"?>
1842            |<project xmlns="http://maven.apache.org/POM/4.0.0"
1843            |       xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
1844            |       xsi:schemaLocation="http://maven.apache.org/POM/4.0.0
1845            |       http://maven.apache.org/xsd/maven-4.0.0.xsd">
1846            |   <modelVersion>4.0.0</modelVersion>
1847            |   <groupId>%s</groupId>
1848            |   <artifactId>%s</artifactId>
1849            |   <version>%s</version>
1850            |</project>
1851            """ % (group_id, artifact_id, version)).lstrip(),
1852            os.path.join(group_id, artifact_id, version))
1853        self.createFileInZip("%s.py" % artifact_id, """
1854            |def myfunc(x):
1855            |    return x + 1
1856            """, ".jar", os.path.join(group_id, artifact_id, version),
1857                             "%s-%s" % (artifact_id, version))
1858
1859    def test_single_script(self):
1860        """Submit and test a single script file"""
1861        script = self.createTempFile("test.py", """
1862            |from pyspark import SparkContext
1863            |
1864            |sc = SparkContext()
1865            |print(sc.parallelize([1, 2, 3]).map(lambda x: x * 2).collect())
1866            """)
1867        proc = subprocess.Popen([self.sparkSubmit, script], stdout=subprocess.PIPE)
1868        out, err = proc.communicate()
1869        self.assertEqual(0, proc.returncode)
1870        self.assertIn("[2, 4, 6]", out.decode('utf-8'))
1871
1872    def test_script_with_local_functions(self):
1873        """Submit and test a single script file calling a global function"""
1874        script = self.createTempFile("test.py", """
1875            |from pyspark import SparkContext
1876            |
1877            |def foo(x):
1878            |    return x * 3
1879            |
1880            |sc = SparkContext()
1881            |print(sc.parallelize([1, 2, 3]).map(foo).collect())
1882            """)
1883        proc = subprocess.Popen([self.sparkSubmit, script], stdout=subprocess.PIPE)
1884        out, err = proc.communicate()
1885        self.assertEqual(0, proc.returncode)
1886        self.assertIn("[3, 6, 9]", out.decode('utf-8'))
1887
1888    def test_module_dependency(self):
1889        """Submit and test a script with a dependency on another module"""
1890        script = self.createTempFile("test.py", """
1891            |from pyspark import SparkContext
1892            |from mylib import myfunc
1893            |
1894            |sc = SparkContext()
1895            |print(sc.parallelize([1, 2, 3]).map(myfunc).collect())
1896            """)
1897        zip = self.createFileInZip("mylib.py", """
1898            |def myfunc(x):
1899            |    return x + 1
1900            """)
1901        proc = subprocess.Popen([self.sparkSubmit, "--py-files", zip, script],
1902                                stdout=subprocess.PIPE)
1903        out, err = proc.communicate()
1904        self.assertEqual(0, proc.returncode)
1905        self.assertIn("[2, 3, 4]", out.decode('utf-8'))
1906
1907    def test_module_dependency_on_cluster(self):
1908        """Submit and test a script with a dependency on another module on a cluster"""
1909        script = self.createTempFile("test.py", """
1910            |from pyspark import SparkContext
1911            |from mylib import myfunc
1912            |
1913            |sc = SparkContext()
1914            |print(sc.parallelize([1, 2, 3]).map(myfunc).collect())
1915            """)
1916        zip = self.createFileInZip("mylib.py", """
1917            |def myfunc(x):
1918            |    return x + 1
1919            """)
1920        proc = subprocess.Popen([self.sparkSubmit, "--py-files", zip, "--master",
1921                                "local-cluster[1,1,1024]", script],
1922                                stdout=subprocess.PIPE)
1923        out, err = proc.communicate()
1924        self.assertEqual(0, proc.returncode)
1925        self.assertIn("[2, 3, 4]", out.decode('utf-8'))
1926
1927    def test_package_dependency(self):
1928        """Submit and test a script with a dependency on a Spark Package"""
1929        script = self.createTempFile("test.py", """
1930            |from pyspark import SparkContext
1931            |from mylib import myfunc
1932            |
1933            |sc = SparkContext()
1934            |print(sc.parallelize([1, 2, 3]).map(myfunc).collect())
1935            """)
1936        self.create_spark_package("a:mylib:0.1")
1937        proc = subprocess.Popen([self.sparkSubmit, "--packages", "a:mylib:0.1", "--repositories",
1938                                 "file:" + self.programDir, script], stdout=subprocess.PIPE)
1939        out, err = proc.communicate()
1940        self.assertEqual(0, proc.returncode)
1941        self.assertIn("[2, 3, 4]", out.decode('utf-8'))
1942
1943    def test_package_dependency_on_cluster(self):
1944        """Submit and test a script with a dependency on a Spark Package on a cluster"""
1945        script = self.createTempFile("test.py", """
1946            |from pyspark import SparkContext
1947            |from mylib import myfunc
1948            |
1949            |sc = SparkContext()
1950            |print(sc.parallelize([1, 2, 3]).map(myfunc).collect())
1951            """)
1952        self.create_spark_package("a:mylib:0.1")
1953        proc = subprocess.Popen([self.sparkSubmit, "--packages", "a:mylib:0.1", "--repositories",
1954                                 "file:" + self.programDir, "--master",
1955                                 "local-cluster[1,1,1024]", script], stdout=subprocess.PIPE)
1956        out, err = proc.communicate()
1957        self.assertEqual(0, proc.returncode)
1958        self.assertIn("[2, 3, 4]", out.decode('utf-8'))
1959
1960    def test_single_script_on_cluster(self):
1961        """Submit and test a single script on a cluster"""
1962        script = self.createTempFile("test.py", """
1963            |from pyspark import SparkContext
1964            |
1965            |def foo(x):
1966            |    return x * 2
1967            |
1968            |sc = SparkContext()
1969            |print(sc.parallelize([1, 2, 3]).map(foo).collect())
1970            """)
1971        # this will fail if you have different spark.executor.memory
1972        # in conf/spark-defaults.conf
1973        proc = subprocess.Popen(
1974            [self.sparkSubmit, "--master", "local-cluster[1,1,1024]", script],
1975            stdout=subprocess.PIPE)
1976        out, err = proc.communicate()
1977        self.assertEqual(0, proc.returncode)
1978        self.assertIn("[2, 4, 6]", out.decode('utf-8'))
1979
1980    def test_user_configuration(self):
1981        """Make sure user configuration is respected (SPARK-19307)"""
1982        script = self.createTempFile("test.py", """
1983            |from pyspark import SparkConf, SparkContext
1984            |
1985            |conf = SparkConf().set("spark.test_config", "1")
1986            |sc = SparkContext(conf = conf)
1987            |try:
1988            |    if sc._conf.get("spark.test_config") != "1":
1989            |        raise Exception("Cannot find spark.test_config in SparkContext's conf.")
1990            |finally:
1991            |    sc.stop()
1992            """)
1993        proc = subprocess.Popen(
1994            [self.sparkSubmit, "--master", "local", script],
1995            stdout=subprocess.PIPE,
1996            stderr=subprocess.STDOUT)
1997        out, err = proc.communicate()
1998        self.assertEqual(0, proc.returncode, msg="Process failed with error:\n {0}".format(out))
1999
2000
2001class ContextTests(unittest.TestCase):
2002
2003    def test_failed_sparkcontext_creation(self):
2004        # Regression test for SPARK-1550
2005        self.assertRaises(Exception, lambda: SparkContext("an-invalid-master-name"))
2006
2007    def test_get_or_create(self):
2008        with SparkContext.getOrCreate() as sc:
2009            self.assertTrue(SparkContext.getOrCreate() is sc)
2010
2011    def test_parallelize_eager_cleanup(self):
2012        with SparkContext() as sc:
2013            temp_files = os.listdir(sc._temp_dir)
2014            rdd = sc.parallelize([0, 1, 2])
2015            post_parallalize_temp_files = os.listdir(sc._temp_dir)
2016            self.assertEqual(temp_files, post_parallalize_temp_files)
2017
2018    def test_set_conf(self):
2019        # This is for an internal use case. When there is an existing SparkContext,
2020        # SparkSession's builder needs to set configs into SparkContext's conf.
2021        sc = SparkContext()
2022        sc._conf.set("spark.test.SPARK16224", "SPARK16224")
2023        self.assertEqual(sc._jsc.sc().conf().get("spark.test.SPARK16224"), "SPARK16224")
2024        sc.stop()
2025
2026    def test_stop(self):
2027        sc = SparkContext()
2028        self.assertNotEqual(SparkContext._active_spark_context, None)
2029        sc.stop()
2030        self.assertEqual(SparkContext._active_spark_context, None)
2031
2032    def test_with(self):
2033        with SparkContext() as sc:
2034            self.assertNotEqual(SparkContext._active_spark_context, None)
2035        self.assertEqual(SparkContext._active_spark_context, None)
2036
2037    def test_with_exception(self):
2038        try:
2039            with SparkContext() as sc:
2040                self.assertNotEqual(SparkContext._active_spark_context, None)
2041                raise Exception()
2042        except:
2043            pass
2044        self.assertEqual(SparkContext._active_spark_context, None)
2045
2046    def test_with_stop(self):
2047        with SparkContext() as sc:
2048            self.assertNotEqual(SparkContext._active_spark_context, None)
2049            sc.stop()
2050        self.assertEqual(SparkContext._active_spark_context, None)
2051
2052    def test_progress_api(self):
2053        with SparkContext() as sc:
2054            sc.setJobGroup('test_progress_api', '', True)
2055            rdd = sc.parallelize(range(10)).map(lambda x: time.sleep(100))
2056
2057            def run():
2058                try:
2059                    rdd.count()
2060                except Exception:
2061                    pass
2062            t = threading.Thread(target=run)
2063            t.daemon = True
2064            t.start()
2065            # wait for scheduler to start
2066            time.sleep(1)
2067
2068            tracker = sc.statusTracker()
2069            jobIds = tracker.getJobIdsForGroup('test_progress_api')
2070            self.assertEqual(1, len(jobIds))
2071            job = tracker.getJobInfo(jobIds[0])
2072            self.assertEqual(1, len(job.stageIds))
2073            stage = tracker.getStageInfo(job.stageIds[0])
2074            self.assertEqual(rdd.getNumPartitions(), stage.numTasks)
2075
2076            sc.cancelAllJobs()
2077            t.join()
2078            # wait for event listener to update the status
2079            time.sleep(1)
2080
2081            job = tracker.getJobInfo(jobIds[0])
2082            self.assertEqual('FAILED', job.status)
2083            self.assertEqual([], tracker.getActiveJobsIds())
2084            self.assertEqual([], tracker.getActiveStageIds())
2085
2086            sc.stop()
2087
2088    def test_startTime(self):
2089        with SparkContext() as sc:
2090            self.assertGreater(sc.startTime, 0)
2091
2092
2093class ConfTests(unittest.TestCase):
2094    def test_memory_conf(self):
2095        memoryList = ["1T", "1G", "1M", "1024K"]
2096        for memory in memoryList:
2097            sc = SparkContext(conf=SparkConf().set("spark.python.worker.memory", memory))
2098            l = list(range(1024))
2099            random.shuffle(l)
2100            rdd = sc.parallelize(l, 4)
2101            self.assertEqual(sorted(l), rdd.sortBy(lambda x: x).collect())
2102            sc.stop()
2103
2104
2105class KeywordOnlyTests(unittest.TestCase):
2106    class Wrapped(object):
2107        @keyword_only
2108        def set(self, x=None, y=None):
2109            if "x" in self._input_kwargs:
2110                self._x = self._input_kwargs["x"]
2111            if "y" in self._input_kwargs:
2112                self._y = self._input_kwargs["y"]
2113            return x, y
2114
2115    def test_keywords(self):
2116        w = self.Wrapped()
2117        x, y = w.set(y=1)
2118        self.assertEqual(y, 1)
2119        self.assertEqual(y, w._y)
2120        self.assertIsNone(x)
2121        self.assertFalse(hasattr(w, "_x"))
2122
2123    def test_non_keywords(self):
2124        w = self.Wrapped()
2125        self.assertRaises(TypeError, lambda: w.set(0, y=1))
2126
2127    def test_kwarg_ownership(self):
2128        # test _input_kwargs is owned by each class instance and not a shared static variable
2129        class Setter(object):
2130            @keyword_only
2131            def set(self, x=None, other=None, other_x=None):
2132                if "other" in self._input_kwargs:
2133                    self._input_kwargs["other"].set(x=self._input_kwargs["other_x"])
2134                self._x = self._input_kwargs["x"]
2135
2136        a = Setter()
2137        b = Setter()
2138        a.set(x=1, other=b, other_x=2)
2139        self.assertEqual(a._x, 1)
2140        self.assertEqual(b._x, 2)
2141
2142
2143@unittest.skipIf(not _have_scipy, "SciPy not installed")
2144class SciPyTests(PySparkTestCase):
2145
2146    """General PySpark tests that depend on scipy """
2147
2148    def test_serialize(self):
2149        from scipy.special import gammaln
2150        x = range(1, 5)
2151        expected = list(map(gammaln, x))
2152        observed = self.sc.parallelize(x).map(gammaln).collect()
2153        self.assertEqual(expected, observed)
2154
2155
2156@unittest.skipIf(not _have_numpy, "NumPy not installed")
2157class NumPyTests(PySparkTestCase):
2158
2159    """General PySpark tests that depend on numpy """
2160
2161    def test_statcounter_array(self):
2162        x = self.sc.parallelize([np.array([1.0, 1.0]), np.array([2.0, 2.0]), np.array([3.0, 3.0])])
2163        s = x.stats()
2164        self.assertSequenceEqual([2.0, 2.0], s.mean().tolist())
2165        self.assertSequenceEqual([1.0, 1.0], s.min().tolist())
2166        self.assertSequenceEqual([3.0, 3.0], s.max().tolist())
2167        self.assertSequenceEqual([1.0, 1.0], s.sampleStdev().tolist())
2168
2169        stats_dict = s.asDict()
2170        self.assertEqual(3, stats_dict['count'])
2171        self.assertSequenceEqual([2.0, 2.0], stats_dict['mean'].tolist())
2172        self.assertSequenceEqual([1.0, 1.0], stats_dict['min'].tolist())
2173        self.assertSequenceEqual([3.0, 3.0], stats_dict['max'].tolist())
2174        self.assertSequenceEqual([6.0, 6.0], stats_dict['sum'].tolist())
2175        self.assertSequenceEqual([1.0, 1.0], stats_dict['stdev'].tolist())
2176        self.assertSequenceEqual([1.0, 1.0], stats_dict['variance'].tolist())
2177
2178        stats_sample_dict = s.asDict(sample=True)
2179        self.assertEqual(3, stats_dict['count'])
2180        self.assertSequenceEqual([2.0, 2.0], stats_sample_dict['mean'].tolist())
2181        self.assertSequenceEqual([1.0, 1.0], stats_sample_dict['min'].tolist())
2182        self.assertSequenceEqual([3.0, 3.0], stats_sample_dict['max'].tolist())
2183        self.assertSequenceEqual([6.0, 6.0], stats_sample_dict['sum'].tolist())
2184        self.assertSequenceEqual(
2185            [0.816496580927726, 0.816496580927726], stats_sample_dict['stdev'].tolist())
2186        self.assertSequenceEqual(
2187            [0.6666666666666666, 0.6666666666666666], stats_sample_dict['variance'].tolist())
2188
2189
2190if __name__ == "__main__":
2191    from pyspark.tests import *
2192    if not _have_scipy:
2193        print("NOTE: Skipping SciPy tests as it does not seem to be installed")
2194    if not _have_numpy:
2195        print("NOTE: Skipping NumPy tests as it does not seem to be installed")
2196    if xmlrunner:
2197        unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'))
2198    else:
2199        unittest.main()
2200    if not _have_scipy:
2201        print("NOTE: SciPy tests were skipped as it does not seem to be installed")
2202    if not _have_numpy:
2203        print("NOTE: NumPy tests were skipped as it does not seem to be installed")
2204