1# 2# Licensed to the Apache Software Foundation (ASF) under one or more 3# contributor license agreements. See the NOTICE file distributed with 4# this work for additional information regarding copyright ownership. 5# The ASF licenses this file to You under the Apache License, Version 2.0 6# (the "License"); you may not use this file except in compliance with 7# the License. You may obtain a copy of the License at 8# 9# http://www.apache.org/licenses/LICENSE-2.0 10# 11# Unless required by applicable law or agreed to in writing, software 12# distributed under the License is distributed on an "AS IS" BASIS, 13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14# See the License for the specific language governing permissions and 15# limitations under the License. 16# 17 18""" 19Unit tests for PySpark; additional tests are implemented as doctests in 20individual modules. 21""" 22 23from array import array 24from glob import glob 25import os 26import re 27import shutil 28import subprocess 29import sys 30import tempfile 31import time 32import zipfile 33import random 34import threading 35import hashlib 36 37from py4j.protocol import Py4JJavaError 38try: 39 import xmlrunner 40except ImportError: 41 xmlrunner = None 42 43if sys.version_info[:2] <= (2, 6): 44 try: 45 import unittest2 as unittest 46 except ImportError: 47 sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') 48 sys.exit(1) 49else: 50 import unittest 51 if sys.version_info[0] >= 3: 52 xrange = range 53 basestring = str 54 55if sys.version >= "3": 56 from io import StringIO 57else: 58 from StringIO import StringIO 59 60 61from pyspark import keyword_only 62from pyspark.conf import SparkConf 63from pyspark.context import SparkContext 64from pyspark.rdd import RDD 65from pyspark.files import SparkFiles 66from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer, \ 67 CloudPickleSerializer, CompressedSerializer, UTF8Deserializer, NoOpSerializer, \ 68 PairDeserializer, CartesianDeserializer, AutoBatchedSerializer, AutoSerializer, \ 69 FlattenedValuesSerializer 70from pyspark.shuffle import Aggregator, ExternalMerger, ExternalSorter 71from pyspark import shuffle 72from pyspark.profiler import BasicProfiler 73 74_have_scipy = False 75_have_numpy = False 76try: 77 import scipy.sparse 78 _have_scipy = True 79except: 80 # No SciPy, but that's okay, we'll skip those tests 81 pass 82try: 83 import numpy as np 84 _have_numpy = True 85except: 86 # No NumPy, but that's okay, we'll skip those tests 87 pass 88 89 90SPARK_HOME = os.environ["SPARK_HOME"] 91 92 93class MergerTests(unittest.TestCase): 94 95 def setUp(self): 96 self.N = 1 << 12 97 self.l = [i for i in xrange(self.N)] 98 self.data = list(zip(self.l, self.l)) 99 self.agg = Aggregator(lambda x: [x], 100 lambda x, y: x.append(y) or x, 101 lambda x, y: x.extend(y) or x) 102 103 def test_small_dataset(self): 104 m = ExternalMerger(self.agg, 1000) 105 m.mergeValues(self.data) 106 self.assertEqual(m.spills, 0) 107 self.assertEqual(sum(sum(v) for k, v in m.items()), 108 sum(xrange(self.N))) 109 110 m = ExternalMerger(self.agg, 1000) 111 m.mergeCombiners(map(lambda x_y1: (x_y1[0], [x_y1[1]]), self.data)) 112 self.assertEqual(m.spills, 0) 113 self.assertEqual(sum(sum(v) for k, v in m.items()), 114 sum(xrange(self.N))) 115 116 def test_medium_dataset(self): 117 m = ExternalMerger(self.agg, 20) 118 m.mergeValues(self.data) 119 self.assertTrue(m.spills >= 1) 120 self.assertEqual(sum(sum(v) for k, v in m.items()), 121 sum(xrange(self.N))) 122 123 m = ExternalMerger(self.agg, 10) 124 m.mergeCombiners(map(lambda x_y2: (x_y2[0], [x_y2[1]]), self.data * 3)) 125 self.assertTrue(m.spills >= 1) 126 self.assertEqual(sum(sum(v) for k, v in m.items()), 127 sum(xrange(self.N)) * 3) 128 129 def test_huge_dataset(self): 130 m = ExternalMerger(self.agg, 5, partitions=3) 131 m.mergeCombiners(map(lambda k_v: (k_v[0], [str(k_v[1])]), self.data * 10)) 132 self.assertTrue(m.spills >= 1) 133 self.assertEqual(sum(len(v) for k, v in m.items()), 134 self.N * 10) 135 m._cleanup() 136 137 def test_group_by_key(self): 138 139 def gen_data(N, step): 140 for i in range(1, N + 1, step): 141 for j in range(i): 142 yield (i, [j]) 143 144 def gen_gs(N, step=1): 145 return shuffle.GroupByKey(gen_data(N, step)) 146 147 self.assertEqual(1, len(list(gen_gs(1)))) 148 self.assertEqual(2, len(list(gen_gs(2)))) 149 self.assertEqual(100, len(list(gen_gs(100)))) 150 self.assertEqual(list(range(1, 101)), [k for k, _ in gen_gs(100)]) 151 self.assertTrue(all(list(range(k)) == list(vs) for k, vs in gen_gs(100))) 152 153 for k, vs in gen_gs(50002, 10000): 154 self.assertEqual(k, len(vs)) 155 self.assertEqual(list(range(k)), list(vs)) 156 157 ser = PickleSerializer() 158 l = ser.loads(ser.dumps(list(gen_gs(50002, 30000)))) 159 for k, vs in l: 160 self.assertEqual(k, len(vs)) 161 self.assertEqual(list(range(k)), list(vs)) 162 163 164class SorterTests(unittest.TestCase): 165 def test_in_memory_sort(self): 166 l = list(range(1024)) 167 random.shuffle(l) 168 sorter = ExternalSorter(1024) 169 self.assertEqual(sorted(l), list(sorter.sorted(l))) 170 self.assertEqual(sorted(l, reverse=True), list(sorter.sorted(l, reverse=True))) 171 self.assertEqual(sorted(l, key=lambda x: -x), list(sorter.sorted(l, key=lambda x: -x))) 172 self.assertEqual(sorted(l, key=lambda x: -x, reverse=True), 173 list(sorter.sorted(l, key=lambda x: -x, reverse=True))) 174 175 def test_external_sort(self): 176 class CustomizedSorter(ExternalSorter): 177 def _next_limit(self): 178 return self.memory_limit 179 l = list(range(1024)) 180 random.shuffle(l) 181 sorter = CustomizedSorter(1) 182 self.assertEqual(sorted(l), list(sorter.sorted(l))) 183 self.assertGreater(shuffle.DiskBytesSpilled, 0) 184 last = shuffle.DiskBytesSpilled 185 self.assertEqual(sorted(l, reverse=True), list(sorter.sorted(l, reverse=True))) 186 self.assertGreater(shuffle.DiskBytesSpilled, last) 187 last = shuffle.DiskBytesSpilled 188 self.assertEqual(sorted(l, key=lambda x: -x), list(sorter.sorted(l, key=lambda x: -x))) 189 self.assertGreater(shuffle.DiskBytesSpilled, last) 190 last = shuffle.DiskBytesSpilled 191 self.assertEqual(sorted(l, key=lambda x: -x, reverse=True), 192 list(sorter.sorted(l, key=lambda x: -x, reverse=True))) 193 self.assertGreater(shuffle.DiskBytesSpilled, last) 194 195 def test_external_sort_in_rdd(self): 196 conf = SparkConf().set("spark.python.worker.memory", "1m") 197 sc = SparkContext(conf=conf) 198 l = list(range(10240)) 199 random.shuffle(l) 200 rdd = sc.parallelize(l, 4) 201 self.assertEqual(sorted(l), rdd.sortBy(lambda x: x).collect()) 202 sc.stop() 203 204 205class SerializationTestCase(unittest.TestCase): 206 207 def test_namedtuple(self): 208 from collections import namedtuple 209 from pickle import dumps, loads 210 P = namedtuple("P", "x y") 211 p1 = P(1, 3) 212 p2 = loads(dumps(p1, 2)) 213 self.assertEqual(p1, p2) 214 215 from pyspark.cloudpickle import dumps 216 P2 = loads(dumps(P)) 217 p3 = P2(1, 3) 218 self.assertEqual(p1, p3) 219 220 def test_itemgetter(self): 221 from operator import itemgetter 222 ser = CloudPickleSerializer() 223 d = range(10) 224 getter = itemgetter(1) 225 getter2 = ser.loads(ser.dumps(getter)) 226 self.assertEqual(getter(d), getter2(d)) 227 228 getter = itemgetter(0, 3) 229 getter2 = ser.loads(ser.dumps(getter)) 230 self.assertEqual(getter(d), getter2(d)) 231 232 def test_function_module_name(self): 233 ser = CloudPickleSerializer() 234 func = lambda x: x 235 func2 = ser.loads(ser.dumps(func)) 236 self.assertEqual(func.__module__, func2.__module__) 237 238 def test_attrgetter(self): 239 from operator import attrgetter 240 ser = CloudPickleSerializer() 241 242 class C(object): 243 def __getattr__(self, item): 244 return item 245 d = C() 246 getter = attrgetter("a") 247 getter2 = ser.loads(ser.dumps(getter)) 248 self.assertEqual(getter(d), getter2(d)) 249 getter = attrgetter("a", "b") 250 getter2 = ser.loads(ser.dumps(getter)) 251 self.assertEqual(getter(d), getter2(d)) 252 253 d.e = C() 254 getter = attrgetter("e.a") 255 getter2 = ser.loads(ser.dumps(getter)) 256 self.assertEqual(getter(d), getter2(d)) 257 getter = attrgetter("e.a", "e.b") 258 getter2 = ser.loads(ser.dumps(getter)) 259 self.assertEqual(getter(d), getter2(d)) 260 261 # Regression test for SPARK-3415 262 def test_pickling_file_handles(self): 263 # to be corrected with SPARK-11160 264 if not xmlrunner: 265 ser = CloudPickleSerializer() 266 out1 = sys.stderr 267 out2 = ser.loads(ser.dumps(out1)) 268 self.assertEqual(out1, out2) 269 270 def test_func_globals(self): 271 272 class Unpicklable(object): 273 def __reduce__(self): 274 raise Exception("not picklable") 275 276 global exit 277 exit = Unpicklable() 278 279 ser = CloudPickleSerializer() 280 self.assertRaises(Exception, lambda: ser.dumps(exit)) 281 282 def foo(): 283 sys.exit(0) 284 285 self.assertTrue("exit" in foo.__code__.co_names) 286 ser.dumps(foo) 287 288 def test_compressed_serializer(self): 289 ser = CompressedSerializer(PickleSerializer()) 290 try: 291 from StringIO import StringIO 292 except ImportError: 293 from io import BytesIO as StringIO 294 io = StringIO() 295 ser.dump_stream(["abc", u"123", range(5)], io) 296 io.seek(0) 297 self.assertEqual(["abc", u"123", range(5)], list(ser.load_stream(io))) 298 ser.dump_stream(range(1000), io) 299 io.seek(0) 300 self.assertEqual(["abc", u"123", range(5)] + list(range(1000)), list(ser.load_stream(io))) 301 io.close() 302 303 def test_hash_serializer(self): 304 hash(NoOpSerializer()) 305 hash(UTF8Deserializer()) 306 hash(PickleSerializer()) 307 hash(MarshalSerializer()) 308 hash(AutoSerializer()) 309 hash(BatchedSerializer(PickleSerializer())) 310 hash(AutoBatchedSerializer(MarshalSerializer())) 311 hash(PairDeserializer(NoOpSerializer(), UTF8Deserializer())) 312 hash(CartesianDeserializer(NoOpSerializer(), UTF8Deserializer())) 313 hash(CompressedSerializer(PickleSerializer())) 314 hash(FlattenedValuesSerializer(PickleSerializer())) 315 316 317class QuietTest(object): 318 def __init__(self, sc): 319 self.log4j = sc._jvm.org.apache.log4j 320 321 def __enter__(self): 322 self.old_level = self.log4j.LogManager.getRootLogger().getLevel() 323 self.log4j.LogManager.getRootLogger().setLevel(self.log4j.Level.FATAL) 324 325 def __exit__(self, exc_type, exc_val, exc_tb): 326 self.log4j.LogManager.getRootLogger().setLevel(self.old_level) 327 328 329class PySparkTestCase(unittest.TestCase): 330 331 def setUp(self): 332 self._old_sys_path = list(sys.path) 333 class_name = self.__class__.__name__ 334 self.sc = SparkContext('local[4]', class_name) 335 336 def tearDown(self): 337 self.sc.stop() 338 sys.path = self._old_sys_path 339 340 341class ReusedPySparkTestCase(unittest.TestCase): 342 343 @classmethod 344 def setUpClass(cls): 345 cls.sc = SparkContext('local[4]', cls.__name__) 346 347 @classmethod 348 def tearDownClass(cls): 349 cls.sc.stop() 350 351 352class CheckpointTests(ReusedPySparkTestCase): 353 354 def setUp(self): 355 self.checkpointDir = tempfile.NamedTemporaryFile(delete=False) 356 os.unlink(self.checkpointDir.name) 357 self.sc.setCheckpointDir(self.checkpointDir.name) 358 359 def tearDown(self): 360 shutil.rmtree(self.checkpointDir.name) 361 362 def test_basic_checkpointing(self): 363 parCollection = self.sc.parallelize([1, 2, 3, 4]) 364 flatMappedRDD = parCollection.flatMap(lambda x: range(1, x + 1)) 365 366 self.assertFalse(flatMappedRDD.isCheckpointed()) 367 self.assertTrue(flatMappedRDD.getCheckpointFile() is None) 368 369 flatMappedRDD.checkpoint() 370 result = flatMappedRDD.collect() 371 time.sleep(1) # 1 second 372 self.assertTrue(flatMappedRDD.isCheckpointed()) 373 self.assertEqual(flatMappedRDD.collect(), result) 374 self.assertEqual("file:" + self.checkpointDir.name, 375 os.path.dirname(os.path.dirname(flatMappedRDD.getCheckpointFile()))) 376 377 def test_checkpoint_and_restore(self): 378 parCollection = self.sc.parallelize([1, 2, 3, 4]) 379 flatMappedRDD = parCollection.flatMap(lambda x: [x]) 380 381 self.assertFalse(flatMappedRDD.isCheckpointed()) 382 self.assertTrue(flatMappedRDD.getCheckpointFile() is None) 383 384 flatMappedRDD.checkpoint() 385 flatMappedRDD.count() # forces a checkpoint to be computed 386 time.sleep(1) # 1 second 387 388 self.assertTrue(flatMappedRDD.getCheckpointFile() is not None) 389 recovered = self.sc._checkpointFile(flatMappedRDD.getCheckpointFile(), 390 flatMappedRDD._jrdd_deserializer) 391 self.assertEqual([1, 2, 3, 4], recovered.collect()) 392 393 394class LocalCheckpointTests(ReusedPySparkTestCase): 395 396 def test_basic_localcheckpointing(self): 397 parCollection = self.sc.parallelize([1, 2, 3, 4]) 398 flatMappedRDD = parCollection.flatMap(lambda x: range(1, x + 1)) 399 400 self.assertFalse(flatMappedRDD.isCheckpointed()) 401 self.assertFalse(flatMappedRDD.isLocallyCheckpointed()) 402 403 flatMappedRDD.localCheckpoint() 404 result = flatMappedRDD.collect() 405 time.sleep(1) # 1 second 406 self.assertTrue(flatMappedRDD.isCheckpointed()) 407 self.assertTrue(flatMappedRDD.isLocallyCheckpointed()) 408 self.assertEqual(flatMappedRDD.collect(), result) 409 410 411class AddFileTests(PySparkTestCase): 412 413 def test_add_py_file(self): 414 # To ensure that we're actually testing addPyFile's effects, check that 415 # this job fails due to `userlibrary` not being on the Python path: 416 # disable logging in log4j temporarily 417 def func(x): 418 from userlibrary import UserClass 419 return UserClass().hello() 420 with QuietTest(self.sc): 421 self.assertRaises(Exception, self.sc.parallelize(range(2)).map(func).first) 422 423 # Add the file, so the job should now succeed: 424 path = os.path.join(SPARK_HOME, "python/test_support/userlibrary.py") 425 self.sc.addPyFile(path) 426 res = self.sc.parallelize(range(2)).map(func).first() 427 self.assertEqual("Hello World!", res) 428 429 def test_add_file_locally(self): 430 path = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt") 431 self.sc.addFile(path) 432 download_path = SparkFiles.get("hello.txt") 433 self.assertNotEqual(path, download_path) 434 with open(download_path) as test_file: 435 self.assertEqual("Hello World!\n", test_file.readline()) 436 437 def test_add_file_recursively_locally(self): 438 path = os.path.join(SPARK_HOME, "python/test_support/hello") 439 self.sc.addFile(path, True) 440 download_path = SparkFiles.get("hello") 441 self.assertNotEqual(path, download_path) 442 with open(download_path + "/hello.txt") as test_file: 443 self.assertEqual("Hello World!\n", test_file.readline()) 444 with open(download_path + "/sub_hello/sub_hello.txt") as test_file: 445 self.assertEqual("Sub Hello World!\n", test_file.readline()) 446 447 def test_add_py_file_locally(self): 448 # To ensure that we're actually testing addPyFile's effects, check that 449 # this fails due to `userlibrary` not being on the Python path: 450 def func(): 451 from userlibrary import UserClass 452 self.assertRaises(ImportError, func) 453 path = os.path.join(SPARK_HOME, "python/test_support/userlibrary.py") 454 self.sc.addPyFile(path) 455 from userlibrary import UserClass 456 self.assertEqual("Hello World!", UserClass().hello()) 457 458 def test_add_egg_file_locally(self): 459 # To ensure that we're actually testing addPyFile's effects, check that 460 # this fails due to `userlibrary` not being on the Python path: 461 def func(): 462 from userlib import UserClass 463 self.assertRaises(ImportError, func) 464 path = os.path.join(SPARK_HOME, "python/test_support/userlib-0.1.zip") 465 self.sc.addPyFile(path) 466 from userlib import UserClass 467 self.assertEqual("Hello World from inside a package!", UserClass().hello()) 468 469 def test_overwrite_system_module(self): 470 self.sc.addPyFile(os.path.join(SPARK_HOME, "python/test_support/SimpleHTTPServer.py")) 471 472 import SimpleHTTPServer 473 self.assertEqual("My Server", SimpleHTTPServer.__name__) 474 475 def func(x): 476 import SimpleHTTPServer 477 return SimpleHTTPServer.__name__ 478 479 self.assertEqual(["My Server"], self.sc.parallelize(range(1)).map(func).collect()) 480 481 482class RDDTests(ReusedPySparkTestCase): 483 484 def test_range(self): 485 self.assertEqual(self.sc.range(1, 1).count(), 0) 486 self.assertEqual(self.sc.range(1, 0, -1).count(), 1) 487 self.assertEqual(self.sc.range(0, 1 << 40, 1 << 39).count(), 2) 488 489 def test_id(self): 490 rdd = self.sc.parallelize(range(10)) 491 id = rdd.id() 492 self.assertEqual(id, rdd.id()) 493 rdd2 = rdd.map(str).filter(bool) 494 id2 = rdd2.id() 495 self.assertEqual(id + 1, id2) 496 self.assertEqual(id2, rdd2.id()) 497 498 def test_empty_rdd(self): 499 rdd = self.sc.emptyRDD() 500 self.assertTrue(rdd.isEmpty()) 501 502 def test_sum(self): 503 self.assertEqual(0, self.sc.emptyRDD().sum()) 504 self.assertEqual(6, self.sc.parallelize([1, 2, 3]).sum()) 505 506 def test_to_localiterator(self): 507 from time import sleep 508 rdd = self.sc.parallelize([1, 2, 3]) 509 it = rdd.toLocalIterator() 510 sleep(5) 511 self.assertEqual([1, 2, 3], sorted(it)) 512 513 rdd2 = rdd.repartition(1000) 514 it2 = rdd2.toLocalIterator() 515 sleep(5) 516 self.assertEqual([1, 2, 3], sorted(it2)) 517 518 def test_save_as_textfile_with_unicode(self): 519 # Regression test for SPARK-970 520 x = u"\u00A1Hola, mundo!" 521 data = self.sc.parallelize([x]) 522 tempFile = tempfile.NamedTemporaryFile(delete=True) 523 tempFile.close() 524 data.saveAsTextFile(tempFile.name) 525 raw_contents = b''.join(open(p, 'rb').read() 526 for p in glob(tempFile.name + "/part-0000*")) 527 self.assertEqual(x, raw_contents.strip().decode("utf-8")) 528 529 def test_save_as_textfile_with_utf8(self): 530 x = u"\u00A1Hola, mundo!" 531 data = self.sc.parallelize([x.encode("utf-8")]) 532 tempFile = tempfile.NamedTemporaryFile(delete=True) 533 tempFile.close() 534 data.saveAsTextFile(tempFile.name) 535 raw_contents = b''.join(open(p, 'rb').read() 536 for p in glob(tempFile.name + "/part-0000*")) 537 self.assertEqual(x, raw_contents.strip().decode('utf8')) 538 539 def test_transforming_cartesian_result(self): 540 # Regression test for SPARK-1034 541 rdd1 = self.sc.parallelize([1, 2]) 542 rdd2 = self.sc.parallelize([3, 4]) 543 cart = rdd1.cartesian(rdd2) 544 result = cart.map(lambda x_y3: x_y3[0] + x_y3[1]).collect() 545 546 def test_transforming_pickle_file(self): 547 # Regression test for SPARK-2601 548 data = self.sc.parallelize([u"Hello", u"World!"]) 549 tempFile = tempfile.NamedTemporaryFile(delete=True) 550 tempFile.close() 551 data.saveAsPickleFile(tempFile.name) 552 pickled_file = self.sc.pickleFile(tempFile.name) 553 pickled_file.map(lambda x: x).collect() 554 555 def test_cartesian_on_textfile(self): 556 # Regression test for 557 path = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt") 558 a = self.sc.textFile(path) 559 result = a.cartesian(a).collect() 560 (x, y) = result[0] 561 self.assertEqual(u"Hello World!", x.strip()) 562 self.assertEqual(u"Hello World!", y.strip()) 563 564 def test_cartesian_chaining(self): 565 # Tests for SPARK-16589 566 rdd = self.sc.parallelize(range(10), 2) 567 self.assertSetEqual( 568 set(rdd.cartesian(rdd).cartesian(rdd).collect()), 569 set([((x, y), z) for x in range(10) for y in range(10) for z in range(10)]) 570 ) 571 572 self.assertSetEqual( 573 set(rdd.cartesian(rdd.cartesian(rdd)).collect()), 574 set([(x, (y, z)) for x in range(10) for y in range(10) for z in range(10)]) 575 ) 576 577 self.assertSetEqual( 578 set(rdd.cartesian(rdd.zip(rdd)).collect()), 579 set([(x, (y, y)) for x in range(10) for y in range(10)]) 580 ) 581 582 def test_deleting_input_files(self): 583 # Regression test for SPARK-1025 584 tempFile = tempfile.NamedTemporaryFile(delete=False) 585 tempFile.write(b"Hello World!") 586 tempFile.close() 587 data = self.sc.textFile(tempFile.name) 588 filtered_data = data.filter(lambda x: True) 589 self.assertEqual(1, filtered_data.count()) 590 os.unlink(tempFile.name) 591 with QuietTest(self.sc): 592 self.assertRaises(Exception, lambda: filtered_data.count()) 593 594 def test_sampling_default_seed(self): 595 # Test for SPARK-3995 (default seed setting) 596 data = self.sc.parallelize(xrange(1000), 1) 597 subset = data.takeSample(False, 10) 598 self.assertEqual(len(subset), 10) 599 600 def test_aggregate_mutable_zero_value(self): 601 # Test for SPARK-9021; uses aggregate and treeAggregate to build dict 602 # representing a counter of ints 603 # NOTE: dict is used instead of collections.Counter for Python 2.6 604 # compatibility 605 from collections import defaultdict 606 607 # Show that single or multiple partitions work 608 data1 = self.sc.range(10, numSlices=1) 609 data2 = self.sc.range(10, numSlices=2) 610 611 def seqOp(x, y): 612 x[y] += 1 613 return x 614 615 def comboOp(x, y): 616 for key, val in y.items(): 617 x[key] += val 618 return x 619 620 counts1 = data1.aggregate(defaultdict(int), seqOp, comboOp) 621 counts2 = data2.aggregate(defaultdict(int), seqOp, comboOp) 622 counts3 = data1.treeAggregate(defaultdict(int), seqOp, comboOp, 2) 623 counts4 = data2.treeAggregate(defaultdict(int), seqOp, comboOp, 2) 624 625 ground_truth = defaultdict(int, dict((i, 1) for i in range(10))) 626 self.assertEqual(counts1, ground_truth) 627 self.assertEqual(counts2, ground_truth) 628 self.assertEqual(counts3, ground_truth) 629 self.assertEqual(counts4, ground_truth) 630 631 def test_aggregate_by_key_mutable_zero_value(self): 632 # Test for SPARK-9021; uses aggregateByKey to make a pair RDD that 633 # contains lists of all values for each key in the original RDD 634 635 # list(range(...)) for Python 3.x compatibility (can't use * operator 636 # on a range object) 637 # list(zip(...)) for Python 3.x compatibility (want to parallelize a 638 # collection, not a zip object) 639 tuples = list(zip(list(range(10))*2, [1]*20)) 640 # Show that single or multiple partitions work 641 data1 = self.sc.parallelize(tuples, 1) 642 data2 = self.sc.parallelize(tuples, 2) 643 644 def seqOp(x, y): 645 x.append(y) 646 return x 647 648 def comboOp(x, y): 649 x.extend(y) 650 return x 651 652 values1 = data1.aggregateByKey([], seqOp, comboOp).collect() 653 values2 = data2.aggregateByKey([], seqOp, comboOp).collect() 654 # Sort lists to ensure clean comparison with ground_truth 655 values1.sort() 656 values2.sort() 657 658 ground_truth = [(i, [1]*2) for i in range(10)] 659 self.assertEqual(values1, ground_truth) 660 self.assertEqual(values2, ground_truth) 661 662 def test_fold_mutable_zero_value(self): 663 # Test for SPARK-9021; uses fold to merge an RDD of dict counters into 664 # a single dict 665 # NOTE: dict is used instead of collections.Counter for Python 2.6 666 # compatibility 667 from collections import defaultdict 668 669 counts1 = defaultdict(int, dict((i, 1) for i in range(10))) 670 counts2 = defaultdict(int, dict((i, 1) for i in range(3, 8))) 671 counts3 = defaultdict(int, dict((i, 1) for i in range(4, 7))) 672 counts4 = defaultdict(int, dict((i, 1) for i in range(5, 6))) 673 all_counts = [counts1, counts2, counts3, counts4] 674 # Show that single or multiple partitions work 675 data1 = self.sc.parallelize(all_counts, 1) 676 data2 = self.sc.parallelize(all_counts, 2) 677 678 def comboOp(x, y): 679 for key, val in y.items(): 680 x[key] += val 681 return x 682 683 fold1 = data1.fold(defaultdict(int), comboOp) 684 fold2 = data2.fold(defaultdict(int), comboOp) 685 686 ground_truth = defaultdict(int) 687 for counts in all_counts: 688 for key, val in counts.items(): 689 ground_truth[key] += val 690 self.assertEqual(fold1, ground_truth) 691 self.assertEqual(fold2, ground_truth) 692 693 def test_fold_by_key_mutable_zero_value(self): 694 # Test for SPARK-9021; uses foldByKey to make a pair RDD that contains 695 # lists of all values for each key in the original RDD 696 697 tuples = [(i, range(i)) for i in range(10)]*2 698 # Show that single or multiple partitions work 699 data1 = self.sc.parallelize(tuples, 1) 700 data2 = self.sc.parallelize(tuples, 2) 701 702 def comboOp(x, y): 703 x.extend(y) 704 return x 705 706 values1 = data1.foldByKey([], comboOp).collect() 707 values2 = data2.foldByKey([], comboOp).collect() 708 # Sort lists to ensure clean comparison with ground_truth 709 values1.sort() 710 values2.sort() 711 712 # list(range(...)) for Python 3.x compatibility 713 ground_truth = [(i, list(range(i))*2) for i in range(10)] 714 self.assertEqual(values1, ground_truth) 715 self.assertEqual(values2, ground_truth) 716 717 def test_aggregate_by_key(self): 718 data = self.sc.parallelize([(1, 1), (1, 1), (3, 2), (5, 1), (5, 3)], 2) 719 720 def seqOp(x, y): 721 x.add(y) 722 return x 723 724 def combOp(x, y): 725 x |= y 726 return x 727 728 sets = dict(data.aggregateByKey(set(), seqOp, combOp).collect()) 729 self.assertEqual(3, len(sets)) 730 self.assertEqual(set([1]), sets[1]) 731 self.assertEqual(set([2]), sets[3]) 732 self.assertEqual(set([1, 3]), sets[5]) 733 734 def test_itemgetter(self): 735 rdd = self.sc.parallelize([range(10)]) 736 from operator import itemgetter 737 self.assertEqual([1], rdd.map(itemgetter(1)).collect()) 738 self.assertEqual([(2, 3)], rdd.map(itemgetter(2, 3)).collect()) 739 740 def test_namedtuple_in_rdd(self): 741 from collections import namedtuple 742 Person = namedtuple("Person", "id firstName lastName") 743 jon = Person(1, "Jon", "Doe") 744 jane = Person(2, "Jane", "Doe") 745 theDoes = self.sc.parallelize([jon, jane]) 746 self.assertEqual([jon, jane], theDoes.collect()) 747 748 def test_large_broadcast(self): 749 N = 10000 750 data = [[float(i) for i in range(300)] for i in range(N)] 751 bdata = self.sc.broadcast(data) # 27MB 752 m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum() 753 self.assertEqual(N, m) 754 755 def test_unpersist(self): 756 N = 1000 757 data = [[float(i) for i in range(300)] for i in range(N)] 758 bdata = self.sc.broadcast(data) # 3MB 759 bdata.unpersist() 760 m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum() 761 self.assertEqual(N, m) 762 bdata.destroy() 763 try: 764 self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum() 765 except Exception as e: 766 pass 767 else: 768 raise Exception("job should fail after destroy the broadcast") 769 770 def test_multiple_broadcasts(self): 771 N = 1 << 21 772 b1 = self.sc.broadcast(set(range(N))) # multiple blocks in JVM 773 r = list(range(1 << 15)) 774 random.shuffle(r) 775 s = str(r).encode() 776 checksum = hashlib.md5(s).hexdigest() 777 b2 = self.sc.broadcast(s) 778 r = list(set(self.sc.parallelize(range(10), 10).map( 779 lambda x: (len(b1.value), hashlib.md5(b2.value).hexdigest())).collect())) 780 self.assertEqual(1, len(r)) 781 size, csum = r[0] 782 self.assertEqual(N, size) 783 self.assertEqual(checksum, csum) 784 785 random.shuffle(r) 786 s = str(r).encode() 787 checksum = hashlib.md5(s).hexdigest() 788 b2 = self.sc.broadcast(s) 789 r = list(set(self.sc.parallelize(range(10), 10).map( 790 lambda x: (len(b1.value), hashlib.md5(b2.value).hexdigest())).collect())) 791 self.assertEqual(1, len(r)) 792 size, csum = r[0] 793 self.assertEqual(N, size) 794 self.assertEqual(checksum, csum) 795 796 def test_large_closure(self): 797 N = 200000 798 data = [float(i) for i in xrange(N)] 799 rdd = self.sc.parallelize(range(1), 1).map(lambda x: len(data)) 800 self.assertEqual(N, rdd.first()) 801 # regression test for SPARK-6886 802 self.assertEqual(1, rdd.map(lambda x: (x, 1)).groupByKey().count()) 803 804 def test_zip_with_different_serializers(self): 805 a = self.sc.parallelize(range(5)) 806 b = self.sc.parallelize(range(100, 105)) 807 self.assertEqual(a.zip(b).collect(), [(0, 100), (1, 101), (2, 102), (3, 103), (4, 104)]) 808 a = a._reserialize(BatchedSerializer(PickleSerializer(), 2)) 809 b = b._reserialize(MarshalSerializer()) 810 self.assertEqual(a.zip(b).collect(), [(0, 100), (1, 101), (2, 102), (3, 103), (4, 104)]) 811 # regression test for SPARK-4841 812 path = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt") 813 t = self.sc.textFile(path) 814 cnt = t.count() 815 self.assertEqual(cnt, t.zip(t).count()) 816 rdd = t.map(str) 817 self.assertEqual(cnt, t.zip(rdd).count()) 818 # regression test for bug in _reserializer() 819 self.assertEqual(cnt, t.zip(rdd).count()) 820 821 def test_zip_with_different_object_sizes(self): 822 # regress test for SPARK-5973 823 a = self.sc.parallelize(xrange(10000)).map(lambda i: '*' * i) 824 b = self.sc.parallelize(xrange(10000, 20000)).map(lambda i: '*' * i) 825 self.assertEqual(10000, a.zip(b).count()) 826 827 def test_zip_with_different_number_of_items(self): 828 a = self.sc.parallelize(range(5), 2) 829 # different number of partitions 830 b = self.sc.parallelize(range(100, 106), 3) 831 self.assertRaises(ValueError, lambda: a.zip(b)) 832 with QuietTest(self.sc): 833 # different number of batched items in JVM 834 b = self.sc.parallelize(range(100, 104), 2) 835 self.assertRaises(Exception, lambda: a.zip(b).count()) 836 # different number of items in one pair 837 b = self.sc.parallelize(range(100, 106), 2) 838 self.assertRaises(Exception, lambda: a.zip(b).count()) 839 # same total number of items, but different distributions 840 a = self.sc.parallelize([2, 3], 2).flatMap(range) 841 b = self.sc.parallelize([3, 2], 2).flatMap(range) 842 self.assertEqual(a.count(), b.count()) 843 self.assertRaises(Exception, lambda: a.zip(b).count()) 844 845 def test_count_approx_distinct(self): 846 rdd = self.sc.parallelize(xrange(1000)) 847 self.assertTrue(950 < rdd.countApproxDistinct(0.03) < 1050) 848 self.assertTrue(950 < rdd.map(float).countApproxDistinct(0.03) < 1050) 849 self.assertTrue(950 < rdd.map(str).countApproxDistinct(0.03) < 1050) 850 self.assertTrue(950 < rdd.map(lambda x: (x, -x)).countApproxDistinct(0.03) < 1050) 851 852 rdd = self.sc.parallelize([i % 20 for i in range(1000)], 7) 853 self.assertTrue(18 < rdd.countApproxDistinct() < 22) 854 self.assertTrue(18 < rdd.map(float).countApproxDistinct() < 22) 855 self.assertTrue(18 < rdd.map(str).countApproxDistinct() < 22) 856 self.assertTrue(18 < rdd.map(lambda x: (x, -x)).countApproxDistinct() < 22) 857 858 self.assertRaises(ValueError, lambda: rdd.countApproxDistinct(0.00000001)) 859 860 def test_histogram(self): 861 # empty 862 rdd = self.sc.parallelize([]) 863 self.assertEqual([0], rdd.histogram([0, 10])[1]) 864 self.assertEqual([0, 0], rdd.histogram([0, 4, 10])[1]) 865 self.assertRaises(ValueError, lambda: rdd.histogram(1)) 866 867 # out of range 868 rdd = self.sc.parallelize([10.01, -0.01]) 869 self.assertEqual([0], rdd.histogram([0, 10])[1]) 870 self.assertEqual([0, 0], rdd.histogram((0, 4, 10))[1]) 871 872 # in range with one bucket 873 rdd = self.sc.parallelize(range(1, 5)) 874 self.assertEqual([4], rdd.histogram([0, 10])[1]) 875 self.assertEqual([3, 1], rdd.histogram([0, 4, 10])[1]) 876 877 # in range with one bucket exact match 878 self.assertEqual([4], rdd.histogram([1, 4])[1]) 879 880 # out of range with two buckets 881 rdd = self.sc.parallelize([10.01, -0.01]) 882 self.assertEqual([0, 0], rdd.histogram([0, 5, 10])[1]) 883 884 # out of range with two uneven buckets 885 rdd = self.sc.parallelize([10.01, -0.01]) 886 self.assertEqual([0, 0], rdd.histogram([0, 4, 10])[1]) 887 888 # in range with two buckets 889 rdd = self.sc.parallelize([1, 2, 3, 5, 6]) 890 self.assertEqual([3, 2], rdd.histogram([0, 5, 10])[1]) 891 892 # in range with two bucket and None 893 rdd = self.sc.parallelize([1, 2, 3, 5, 6, None, float('nan')]) 894 self.assertEqual([3, 2], rdd.histogram([0, 5, 10])[1]) 895 896 # in range with two uneven buckets 897 rdd = self.sc.parallelize([1, 2, 3, 5, 6]) 898 self.assertEqual([3, 2], rdd.histogram([0, 5, 11])[1]) 899 900 # mixed range with two uneven buckets 901 rdd = self.sc.parallelize([-0.01, 0.0, 1, 2, 3, 5, 6, 11.0, 11.01]) 902 self.assertEqual([4, 3], rdd.histogram([0, 5, 11])[1]) 903 904 # mixed range with four uneven buckets 905 rdd = self.sc.parallelize([-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0, 200.0, 200.1]) 906 self.assertEqual([4, 2, 1, 3], rdd.histogram([0.0, 5.0, 11.0, 12.0, 200.0])[1]) 907 908 # mixed range with uneven buckets and NaN 909 rdd = self.sc.parallelize([-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 910 199.0, 200.0, 200.1, None, float('nan')]) 911 self.assertEqual([4, 2, 1, 3], rdd.histogram([0.0, 5.0, 11.0, 12.0, 200.0])[1]) 912 913 # out of range with infinite buckets 914 rdd = self.sc.parallelize([10.01, -0.01, float('nan'), float("inf")]) 915 self.assertEqual([1, 2], rdd.histogram([float('-inf'), 0, float('inf')])[1]) 916 917 # invalid buckets 918 self.assertRaises(ValueError, lambda: rdd.histogram([])) 919 self.assertRaises(ValueError, lambda: rdd.histogram([1])) 920 self.assertRaises(ValueError, lambda: rdd.histogram(0)) 921 self.assertRaises(TypeError, lambda: rdd.histogram({})) 922 923 # without buckets 924 rdd = self.sc.parallelize(range(1, 5)) 925 self.assertEqual(([1, 4], [4]), rdd.histogram(1)) 926 927 # without buckets single element 928 rdd = self.sc.parallelize([1]) 929 self.assertEqual(([1, 1], [1]), rdd.histogram(1)) 930 931 # without bucket no range 932 rdd = self.sc.parallelize([1] * 4) 933 self.assertEqual(([1, 1], [4]), rdd.histogram(1)) 934 935 # without buckets basic two 936 rdd = self.sc.parallelize(range(1, 5)) 937 self.assertEqual(([1, 2.5, 4], [2, 2]), rdd.histogram(2)) 938 939 # without buckets with more requested than elements 940 rdd = self.sc.parallelize([1, 2]) 941 buckets = [1 + 0.2 * i for i in range(6)] 942 hist = [1, 0, 0, 0, 1] 943 self.assertEqual((buckets, hist), rdd.histogram(5)) 944 945 # invalid RDDs 946 rdd = self.sc.parallelize([1, float('inf')]) 947 self.assertRaises(ValueError, lambda: rdd.histogram(2)) 948 rdd = self.sc.parallelize([float('nan')]) 949 self.assertRaises(ValueError, lambda: rdd.histogram(2)) 950 951 # string 952 rdd = self.sc.parallelize(["ab", "ac", "b", "bd", "ef"], 2) 953 self.assertEqual([2, 2], rdd.histogram(["a", "b", "c"])[1]) 954 self.assertEqual((["ab", "ef"], [5]), rdd.histogram(1)) 955 self.assertRaises(TypeError, lambda: rdd.histogram(2)) 956 957 def test_repartitionAndSortWithinPartitions(self): 958 rdd = self.sc.parallelize([(0, 5), (3, 8), (2, 6), (0, 8), (3, 8), (1, 3)], 2) 959 960 repartitioned = rdd.repartitionAndSortWithinPartitions(2, lambda key: key % 2) 961 partitions = repartitioned.glom().collect() 962 self.assertEqual(partitions[0], [(0, 5), (0, 8), (2, 6)]) 963 self.assertEqual(partitions[1], [(1, 3), (3, 8), (3, 8)]) 964 965 def test_repartition_no_skewed(self): 966 num_partitions = 20 967 a = self.sc.parallelize(range(int(1000)), 2) 968 l = a.repartition(num_partitions).glom().map(len).collect() 969 zeros = len([x for x in l if x == 0]) 970 self.assertTrue(zeros == 0) 971 l = a.coalesce(num_partitions, True).glom().map(len).collect() 972 zeros = len([x for x in l if x == 0]) 973 self.assertTrue(zeros == 0) 974 975 def test_repartition_on_textfile(self): 976 path = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt") 977 rdd = self.sc.textFile(path) 978 result = rdd.repartition(1).collect() 979 self.assertEqual(u"Hello World!", result[0]) 980 981 def test_distinct(self): 982 rdd = self.sc.parallelize((1, 2, 3)*10, 10) 983 self.assertEqual(rdd.getNumPartitions(), 10) 984 self.assertEqual(rdd.distinct().count(), 3) 985 result = rdd.distinct(5) 986 self.assertEqual(result.getNumPartitions(), 5) 987 self.assertEqual(result.count(), 3) 988 989 def test_external_group_by_key(self): 990 self.sc._conf.set("spark.python.worker.memory", "1m") 991 N = 200001 992 kv = self.sc.parallelize(xrange(N)).map(lambda x: (x % 3, x)) 993 gkv = kv.groupByKey().cache() 994 self.assertEqual(3, gkv.count()) 995 filtered = gkv.filter(lambda kv: kv[0] == 1) 996 self.assertEqual(1, filtered.count()) 997 self.assertEqual([(1, N // 3)], filtered.mapValues(len).collect()) 998 self.assertEqual([(N // 3, N // 3)], 999 filtered.values().map(lambda x: (len(x), len(list(x)))).collect()) 1000 result = filtered.collect()[0][1] 1001 self.assertEqual(N // 3, len(result)) 1002 self.assertTrue(isinstance(result.data, shuffle.ExternalListOfList)) 1003 1004 def test_sort_on_empty_rdd(self): 1005 self.assertEqual([], self.sc.parallelize(zip([], [])).sortByKey().collect()) 1006 1007 def test_sample(self): 1008 rdd = self.sc.parallelize(range(0, 100), 4) 1009 wo = rdd.sample(False, 0.1, 2).collect() 1010 wo_dup = rdd.sample(False, 0.1, 2).collect() 1011 self.assertSetEqual(set(wo), set(wo_dup)) 1012 wr = rdd.sample(True, 0.2, 5).collect() 1013 wr_dup = rdd.sample(True, 0.2, 5).collect() 1014 self.assertSetEqual(set(wr), set(wr_dup)) 1015 wo_s10 = rdd.sample(False, 0.3, 10).collect() 1016 wo_s20 = rdd.sample(False, 0.3, 20).collect() 1017 self.assertNotEqual(set(wo_s10), set(wo_s20)) 1018 wr_s11 = rdd.sample(True, 0.4, 11).collect() 1019 wr_s21 = rdd.sample(True, 0.4, 21).collect() 1020 self.assertNotEqual(set(wr_s11), set(wr_s21)) 1021 1022 def test_null_in_rdd(self): 1023 jrdd = self.sc._jvm.PythonUtils.generateRDDWithNull(self.sc._jsc) 1024 rdd = RDD(jrdd, self.sc, UTF8Deserializer()) 1025 self.assertEqual([u"a", None, u"b"], rdd.collect()) 1026 rdd = RDD(jrdd, self.sc, NoOpSerializer()) 1027 self.assertEqual([b"a", None, b"b"], rdd.collect()) 1028 1029 def test_multiple_python_java_RDD_conversions(self): 1030 # Regression test for SPARK-5361 1031 data = [ 1032 (u'1', {u'director': u'David Lean'}), 1033 (u'2', {u'director': u'Andrew Dominik'}) 1034 ] 1035 data_rdd = self.sc.parallelize(data) 1036 data_java_rdd = data_rdd._to_java_object_rdd() 1037 data_python_rdd = self.sc._jvm.SerDeUtil.javaToPython(data_java_rdd) 1038 converted_rdd = RDD(data_python_rdd, self.sc) 1039 self.assertEqual(2, converted_rdd.count()) 1040 1041 # conversion between python and java RDD threw exceptions 1042 data_java_rdd = converted_rdd._to_java_object_rdd() 1043 data_python_rdd = self.sc._jvm.SerDeUtil.javaToPython(data_java_rdd) 1044 converted_rdd = RDD(data_python_rdd, self.sc) 1045 self.assertEqual(2, converted_rdd.count()) 1046 1047 def test_narrow_dependency_in_join(self): 1048 rdd = self.sc.parallelize(range(10)).map(lambda x: (x, x)) 1049 parted = rdd.partitionBy(2) 1050 self.assertEqual(2, parted.union(parted).getNumPartitions()) 1051 self.assertEqual(rdd.getNumPartitions() + 2, parted.union(rdd).getNumPartitions()) 1052 self.assertEqual(rdd.getNumPartitions() + 2, rdd.union(parted).getNumPartitions()) 1053 1054 tracker = self.sc.statusTracker() 1055 1056 self.sc.setJobGroup("test1", "test", True) 1057 d = sorted(parted.join(parted).collect()) 1058 self.assertEqual(10, len(d)) 1059 self.assertEqual((0, (0, 0)), d[0]) 1060 jobId = tracker.getJobIdsForGroup("test1")[0] 1061 self.assertEqual(2, len(tracker.getJobInfo(jobId).stageIds)) 1062 1063 self.sc.setJobGroup("test2", "test", True) 1064 d = sorted(parted.join(rdd).collect()) 1065 self.assertEqual(10, len(d)) 1066 self.assertEqual((0, (0, 0)), d[0]) 1067 jobId = tracker.getJobIdsForGroup("test2")[0] 1068 self.assertEqual(3, len(tracker.getJobInfo(jobId).stageIds)) 1069 1070 self.sc.setJobGroup("test3", "test", True) 1071 d = sorted(parted.cogroup(parted).collect()) 1072 self.assertEqual(10, len(d)) 1073 self.assertEqual([[0], [0]], list(map(list, d[0][1]))) 1074 jobId = tracker.getJobIdsForGroup("test3")[0] 1075 self.assertEqual(2, len(tracker.getJobInfo(jobId).stageIds)) 1076 1077 self.sc.setJobGroup("test4", "test", True) 1078 d = sorted(parted.cogroup(rdd).collect()) 1079 self.assertEqual(10, len(d)) 1080 self.assertEqual([[0], [0]], list(map(list, d[0][1]))) 1081 jobId = tracker.getJobIdsForGroup("test4")[0] 1082 self.assertEqual(3, len(tracker.getJobInfo(jobId).stageIds)) 1083 1084 # Regression test for SPARK-6294 1085 def test_take_on_jrdd(self): 1086 rdd = self.sc.parallelize(xrange(1 << 20)).map(lambda x: str(x)) 1087 rdd._jrdd.first() 1088 1089 def test_sortByKey_uses_all_partitions_not_only_first_and_last(self): 1090 # Regression test for SPARK-5969 1091 seq = [(i * 59 % 101, i) for i in range(101)] # unsorted sequence 1092 rdd = self.sc.parallelize(seq) 1093 for ascending in [True, False]: 1094 sort = rdd.sortByKey(ascending=ascending, numPartitions=5) 1095 self.assertEqual(sort.collect(), sorted(seq, reverse=not ascending)) 1096 sizes = sort.glom().map(len).collect() 1097 for size in sizes: 1098 self.assertGreater(size, 0) 1099 1100 def test_pipe_functions(self): 1101 data = ['1', '2', '3'] 1102 rdd = self.sc.parallelize(data) 1103 with QuietTest(self.sc): 1104 self.assertEqual([], rdd.pipe('cc').collect()) 1105 self.assertRaises(Py4JJavaError, rdd.pipe('cc', checkCode=True).collect) 1106 result = rdd.pipe('cat').collect() 1107 result.sort() 1108 for x, y in zip(data, result): 1109 self.assertEqual(x, y) 1110 self.assertRaises(Py4JJavaError, rdd.pipe('grep 4', checkCode=True).collect) 1111 self.assertEqual([], rdd.pipe('grep 4').collect()) 1112 1113 1114class ProfilerTests(PySparkTestCase): 1115 1116 def setUp(self): 1117 self._old_sys_path = list(sys.path) 1118 class_name = self.__class__.__name__ 1119 conf = SparkConf().set("spark.python.profile", "true") 1120 self.sc = SparkContext('local[4]', class_name, conf=conf) 1121 1122 def test_profiler(self): 1123 self.do_computation() 1124 1125 profilers = self.sc.profiler_collector.profilers 1126 self.assertEqual(1, len(profilers)) 1127 id, profiler, _ = profilers[0] 1128 stats = profiler.stats() 1129 self.assertTrue(stats is not None) 1130 width, stat_list = stats.get_print_list([]) 1131 func_names = [func_name for fname, n, func_name in stat_list] 1132 self.assertTrue("heavy_foo" in func_names) 1133 1134 old_stdout = sys.stdout 1135 sys.stdout = io = StringIO() 1136 self.sc.show_profiles() 1137 self.assertTrue("heavy_foo" in io.getvalue()) 1138 sys.stdout = old_stdout 1139 1140 d = tempfile.gettempdir() 1141 self.sc.dump_profiles(d) 1142 self.assertTrue("rdd_%d.pstats" % id in os.listdir(d)) 1143 1144 def test_custom_profiler(self): 1145 class TestCustomProfiler(BasicProfiler): 1146 def show(self, id): 1147 self.result = "Custom formatting" 1148 1149 self.sc.profiler_collector.profiler_cls = TestCustomProfiler 1150 1151 self.do_computation() 1152 1153 profilers = self.sc.profiler_collector.profilers 1154 self.assertEqual(1, len(profilers)) 1155 _, profiler, _ = profilers[0] 1156 self.assertTrue(isinstance(profiler, TestCustomProfiler)) 1157 1158 self.sc.show_profiles() 1159 self.assertEqual("Custom formatting", profiler.result) 1160 1161 def do_computation(self): 1162 def heavy_foo(x): 1163 for i in range(1 << 18): 1164 x = 1 1165 1166 rdd = self.sc.parallelize(range(100)) 1167 rdd.foreach(heavy_foo) 1168 1169 1170class InputFormatTests(ReusedPySparkTestCase): 1171 1172 @classmethod 1173 def setUpClass(cls): 1174 ReusedPySparkTestCase.setUpClass() 1175 cls.tempdir = tempfile.NamedTemporaryFile(delete=False) 1176 os.unlink(cls.tempdir.name) 1177 cls.sc._jvm.WriteInputFormatTestDataGenerator.generateData(cls.tempdir.name, cls.sc._jsc) 1178 1179 @classmethod 1180 def tearDownClass(cls): 1181 ReusedPySparkTestCase.tearDownClass() 1182 shutil.rmtree(cls.tempdir.name) 1183 1184 @unittest.skipIf(sys.version >= "3", "serialize array of byte") 1185 def test_sequencefiles(self): 1186 basepath = self.tempdir.name 1187 ints = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfint/", 1188 "org.apache.hadoop.io.IntWritable", 1189 "org.apache.hadoop.io.Text").collect()) 1190 ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')] 1191 self.assertEqual(ints, ei) 1192 1193 doubles = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfdouble/", 1194 "org.apache.hadoop.io.DoubleWritable", 1195 "org.apache.hadoop.io.Text").collect()) 1196 ed = [(1.0, u'aa'), (1.0, u'aa'), (2.0, u'aa'), (2.0, u'bb'), (2.0, u'bb'), (3.0, u'cc')] 1197 self.assertEqual(doubles, ed) 1198 1199 bytes = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfbytes/", 1200 "org.apache.hadoop.io.IntWritable", 1201 "org.apache.hadoop.io.BytesWritable").collect()) 1202 ebs = [(1, bytearray('aa', 'utf-8')), 1203 (1, bytearray('aa', 'utf-8')), 1204 (2, bytearray('aa', 'utf-8')), 1205 (2, bytearray('bb', 'utf-8')), 1206 (2, bytearray('bb', 'utf-8')), 1207 (3, bytearray('cc', 'utf-8'))] 1208 self.assertEqual(bytes, ebs) 1209 1210 text = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sftext/", 1211 "org.apache.hadoop.io.Text", 1212 "org.apache.hadoop.io.Text").collect()) 1213 et = [(u'1', u'aa'), 1214 (u'1', u'aa'), 1215 (u'2', u'aa'), 1216 (u'2', u'bb'), 1217 (u'2', u'bb'), 1218 (u'3', u'cc')] 1219 self.assertEqual(text, et) 1220 1221 bools = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfbool/", 1222 "org.apache.hadoop.io.IntWritable", 1223 "org.apache.hadoop.io.BooleanWritable").collect()) 1224 eb = [(1, False), (1, True), (2, False), (2, False), (2, True), (3, True)] 1225 self.assertEqual(bools, eb) 1226 1227 nulls = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfnull/", 1228 "org.apache.hadoop.io.IntWritable", 1229 "org.apache.hadoop.io.BooleanWritable").collect()) 1230 en = [(1, None), (1, None), (2, None), (2, None), (2, None), (3, None)] 1231 self.assertEqual(nulls, en) 1232 1233 maps = self.sc.sequenceFile(basepath + "/sftestdata/sfmap/", 1234 "org.apache.hadoop.io.IntWritable", 1235 "org.apache.hadoop.io.MapWritable").collect() 1236 em = [(1, {}), 1237 (1, {3.0: u'bb'}), 1238 (2, {1.0: u'aa'}), 1239 (2, {1.0: u'cc'}), 1240 (3, {2.0: u'dd'})] 1241 for v in maps: 1242 self.assertTrue(v in em) 1243 1244 # arrays get pickled to tuples by default 1245 tuples = sorted(self.sc.sequenceFile( 1246 basepath + "/sftestdata/sfarray/", 1247 "org.apache.hadoop.io.IntWritable", 1248 "org.apache.spark.api.python.DoubleArrayWritable").collect()) 1249 et = [(1, ()), 1250 (2, (3.0, 4.0, 5.0)), 1251 (3, (4.0, 5.0, 6.0))] 1252 self.assertEqual(tuples, et) 1253 1254 # with custom converters, primitive arrays can stay as arrays 1255 arrays = sorted(self.sc.sequenceFile( 1256 basepath + "/sftestdata/sfarray/", 1257 "org.apache.hadoop.io.IntWritable", 1258 "org.apache.spark.api.python.DoubleArrayWritable", 1259 valueConverter="org.apache.spark.api.python.WritableToDoubleArrayConverter").collect()) 1260 ea = [(1, array('d')), 1261 (2, array('d', [3.0, 4.0, 5.0])), 1262 (3, array('d', [4.0, 5.0, 6.0]))] 1263 self.assertEqual(arrays, ea) 1264 1265 clazz = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfclass/", 1266 "org.apache.hadoop.io.Text", 1267 "org.apache.spark.api.python.TestWritable").collect()) 1268 cname = u'org.apache.spark.api.python.TestWritable' 1269 ec = [(u'1', {u'__class__': cname, u'double': 1.0, u'int': 1, u'str': u'test1'}), 1270 (u'2', {u'__class__': cname, u'double': 2.3, u'int': 2, u'str': u'test2'}), 1271 (u'3', {u'__class__': cname, u'double': 3.1, u'int': 3, u'str': u'test3'}), 1272 (u'4', {u'__class__': cname, u'double': 4.2, u'int': 4, u'str': u'test4'}), 1273 (u'5', {u'__class__': cname, u'double': 5.5, u'int': 5, u'str': u'test56'})] 1274 self.assertEqual(clazz, ec) 1275 1276 unbatched_clazz = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfclass/", 1277 "org.apache.hadoop.io.Text", 1278 "org.apache.spark.api.python.TestWritable", 1279 ).collect()) 1280 self.assertEqual(unbatched_clazz, ec) 1281 1282 def test_oldhadoop(self): 1283 basepath = self.tempdir.name 1284 ints = sorted(self.sc.hadoopFile(basepath + "/sftestdata/sfint/", 1285 "org.apache.hadoop.mapred.SequenceFileInputFormat", 1286 "org.apache.hadoop.io.IntWritable", 1287 "org.apache.hadoop.io.Text").collect()) 1288 ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')] 1289 self.assertEqual(ints, ei) 1290 1291 hellopath = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt") 1292 oldconf = {"mapred.input.dir": hellopath} 1293 hello = self.sc.hadoopRDD("org.apache.hadoop.mapred.TextInputFormat", 1294 "org.apache.hadoop.io.LongWritable", 1295 "org.apache.hadoop.io.Text", 1296 conf=oldconf).collect() 1297 result = [(0, u'Hello World!')] 1298 self.assertEqual(hello, result) 1299 1300 def test_newhadoop(self): 1301 basepath = self.tempdir.name 1302 ints = sorted(self.sc.newAPIHadoopFile( 1303 basepath + "/sftestdata/sfint/", 1304 "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", 1305 "org.apache.hadoop.io.IntWritable", 1306 "org.apache.hadoop.io.Text").collect()) 1307 ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')] 1308 self.assertEqual(ints, ei) 1309 1310 hellopath = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt") 1311 newconf = {"mapred.input.dir": hellopath} 1312 hello = self.sc.newAPIHadoopRDD("org.apache.hadoop.mapreduce.lib.input.TextInputFormat", 1313 "org.apache.hadoop.io.LongWritable", 1314 "org.apache.hadoop.io.Text", 1315 conf=newconf).collect() 1316 result = [(0, u'Hello World!')] 1317 self.assertEqual(hello, result) 1318 1319 def test_newolderror(self): 1320 basepath = self.tempdir.name 1321 self.assertRaises(Exception, lambda: self.sc.hadoopFile( 1322 basepath + "/sftestdata/sfint/", 1323 "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", 1324 "org.apache.hadoop.io.IntWritable", 1325 "org.apache.hadoop.io.Text")) 1326 1327 self.assertRaises(Exception, lambda: self.sc.newAPIHadoopFile( 1328 basepath + "/sftestdata/sfint/", 1329 "org.apache.hadoop.mapred.SequenceFileInputFormat", 1330 "org.apache.hadoop.io.IntWritable", 1331 "org.apache.hadoop.io.Text")) 1332 1333 def test_bad_inputs(self): 1334 basepath = self.tempdir.name 1335 self.assertRaises(Exception, lambda: self.sc.sequenceFile( 1336 basepath + "/sftestdata/sfint/", 1337 "org.apache.hadoop.io.NotValidWritable", 1338 "org.apache.hadoop.io.Text")) 1339 self.assertRaises(Exception, lambda: self.sc.hadoopFile( 1340 basepath + "/sftestdata/sfint/", 1341 "org.apache.hadoop.mapred.NotValidInputFormat", 1342 "org.apache.hadoop.io.IntWritable", 1343 "org.apache.hadoop.io.Text")) 1344 self.assertRaises(Exception, lambda: self.sc.newAPIHadoopFile( 1345 basepath + "/sftestdata/sfint/", 1346 "org.apache.hadoop.mapreduce.lib.input.NotValidInputFormat", 1347 "org.apache.hadoop.io.IntWritable", 1348 "org.apache.hadoop.io.Text")) 1349 1350 def test_converters(self): 1351 # use of custom converters 1352 basepath = self.tempdir.name 1353 maps = sorted(self.sc.sequenceFile( 1354 basepath + "/sftestdata/sfmap/", 1355 "org.apache.hadoop.io.IntWritable", 1356 "org.apache.hadoop.io.MapWritable", 1357 keyConverter="org.apache.spark.api.python.TestInputKeyConverter", 1358 valueConverter="org.apache.spark.api.python.TestInputValueConverter").collect()) 1359 em = [(u'\x01', []), 1360 (u'\x01', [3.0]), 1361 (u'\x02', [1.0]), 1362 (u'\x02', [1.0]), 1363 (u'\x03', [2.0])] 1364 self.assertEqual(maps, em) 1365 1366 def test_binary_files(self): 1367 path = os.path.join(self.tempdir.name, "binaryfiles") 1368 os.mkdir(path) 1369 data = b"short binary data" 1370 with open(os.path.join(path, "part-0000"), 'wb') as f: 1371 f.write(data) 1372 [(p, d)] = self.sc.binaryFiles(path).collect() 1373 self.assertTrue(p.endswith("part-0000")) 1374 self.assertEqual(d, data) 1375 1376 def test_binary_records(self): 1377 path = os.path.join(self.tempdir.name, "binaryrecords") 1378 os.mkdir(path) 1379 with open(os.path.join(path, "part-0000"), 'w') as f: 1380 for i in range(100): 1381 f.write('%04d' % i) 1382 result = self.sc.binaryRecords(path, 4).map(int).collect() 1383 self.assertEqual(list(range(100)), result) 1384 1385 1386class OutputFormatTests(ReusedPySparkTestCase): 1387 1388 def setUp(self): 1389 self.tempdir = tempfile.NamedTemporaryFile(delete=False) 1390 os.unlink(self.tempdir.name) 1391 1392 def tearDown(self): 1393 shutil.rmtree(self.tempdir.name, ignore_errors=True) 1394 1395 @unittest.skipIf(sys.version >= "3", "serialize array of byte") 1396 def test_sequencefiles(self): 1397 basepath = self.tempdir.name 1398 ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')] 1399 self.sc.parallelize(ei).saveAsSequenceFile(basepath + "/sfint/") 1400 ints = sorted(self.sc.sequenceFile(basepath + "/sfint/").collect()) 1401 self.assertEqual(ints, ei) 1402 1403 ed = [(1.0, u'aa'), (1.0, u'aa'), (2.0, u'aa'), (2.0, u'bb'), (2.0, u'bb'), (3.0, u'cc')] 1404 self.sc.parallelize(ed).saveAsSequenceFile(basepath + "/sfdouble/") 1405 doubles = sorted(self.sc.sequenceFile(basepath + "/sfdouble/").collect()) 1406 self.assertEqual(doubles, ed) 1407 1408 ebs = [(1, bytearray(b'\x00\x07spam\x08')), (2, bytearray(b'\x00\x07spam\x08'))] 1409 self.sc.parallelize(ebs).saveAsSequenceFile(basepath + "/sfbytes/") 1410 bytes = sorted(self.sc.sequenceFile(basepath + "/sfbytes/").collect()) 1411 self.assertEqual(bytes, ebs) 1412 1413 et = [(u'1', u'aa'), 1414 (u'2', u'bb'), 1415 (u'3', u'cc')] 1416 self.sc.parallelize(et).saveAsSequenceFile(basepath + "/sftext/") 1417 text = sorted(self.sc.sequenceFile(basepath + "/sftext/").collect()) 1418 self.assertEqual(text, et) 1419 1420 eb = [(1, False), (1, True), (2, False), (2, False), (2, True), (3, True)] 1421 self.sc.parallelize(eb).saveAsSequenceFile(basepath + "/sfbool/") 1422 bools = sorted(self.sc.sequenceFile(basepath + "/sfbool/").collect()) 1423 self.assertEqual(bools, eb) 1424 1425 en = [(1, None), (1, None), (2, None), (2, None), (2, None), (3, None)] 1426 self.sc.parallelize(en).saveAsSequenceFile(basepath + "/sfnull/") 1427 nulls = sorted(self.sc.sequenceFile(basepath + "/sfnull/").collect()) 1428 self.assertEqual(nulls, en) 1429 1430 em = [(1, {}), 1431 (1, {3.0: u'bb'}), 1432 (2, {1.0: u'aa'}), 1433 (2, {1.0: u'cc'}), 1434 (3, {2.0: u'dd'})] 1435 self.sc.parallelize(em).saveAsSequenceFile(basepath + "/sfmap/") 1436 maps = self.sc.sequenceFile(basepath + "/sfmap/").collect() 1437 for v in maps: 1438 self.assertTrue(v, em) 1439 1440 def test_oldhadoop(self): 1441 basepath = self.tempdir.name 1442 dict_data = [(1, {}), 1443 (1, {"row1": 1.0}), 1444 (2, {"row2": 2.0})] 1445 self.sc.parallelize(dict_data).saveAsHadoopFile( 1446 basepath + "/oldhadoop/", 1447 "org.apache.hadoop.mapred.SequenceFileOutputFormat", 1448 "org.apache.hadoop.io.IntWritable", 1449 "org.apache.hadoop.io.MapWritable") 1450 result = self.sc.hadoopFile( 1451 basepath + "/oldhadoop/", 1452 "org.apache.hadoop.mapred.SequenceFileInputFormat", 1453 "org.apache.hadoop.io.IntWritable", 1454 "org.apache.hadoop.io.MapWritable").collect() 1455 for v in result: 1456 self.assertTrue(v, dict_data) 1457 1458 conf = { 1459 "mapred.output.format.class": "org.apache.hadoop.mapred.SequenceFileOutputFormat", 1460 "mapred.output.key.class": "org.apache.hadoop.io.IntWritable", 1461 "mapred.output.value.class": "org.apache.hadoop.io.MapWritable", 1462 "mapred.output.dir": basepath + "/olddataset/" 1463 } 1464 self.sc.parallelize(dict_data).saveAsHadoopDataset(conf) 1465 input_conf = {"mapred.input.dir": basepath + "/olddataset/"} 1466 result = self.sc.hadoopRDD( 1467 "org.apache.hadoop.mapred.SequenceFileInputFormat", 1468 "org.apache.hadoop.io.IntWritable", 1469 "org.apache.hadoop.io.MapWritable", 1470 conf=input_conf).collect() 1471 for v in result: 1472 self.assertTrue(v, dict_data) 1473 1474 def test_newhadoop(self): 1475 basepath = self.tempdir.name 1476 data = [(1, ""), 1477 (1, "a"), 1478 (2, "bcdf")] 1479 self.sc.parallelize(data).saveAsNewAPIHadoopFile( 1480 basepath + "/newhadoop/", 1481 "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", 1482 "org.apache.hadoop.io.IntWritable", 1483 "org.apache.hadoop.io.Text") 1484 result = sorted(self.sc.newAPIHadoopFile( 1485 basepath + "/newhadoop/", 1486 "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", 1487 "org.apache.hadoop.io.IntWritable", 1488 "org.apache.hadoop.io.Text").collect()) 1489 self.assertEqual(result, data) 1490 1491 conf = { 1492 "mapreduce.outputformat.class": 1493 "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", 1494 "mapred.output.key.class": "org.apache.hadoop.io.IntWritable", 1495 "mapred.output.value.class": "org.apache.hadoop.io.Text", 1496 "mapred.output.dir": basepath + "/newdataset/" 1497 } 1498 self.sc.parallelize(data).saveAsNewAPIHadoopDataset(conf) 1499 input_conf = {"mapred.input.dir": basepath + "/newdataset/"} 1500 new_dataset = sorted(self.sc.newAPIHadoopRDD( 1501 "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", 1502 "org.apache.hadoop.io.IntWritable", 1503 "org.apache.hadoop.io.Text", 1504 conf=input_conf).collect()) 1505 self.assertEqual(new_dataset, data) 1506 1507 @unittest.skipIf(sys.version >= "3", "serialize of array") 1508 def test_newhadoop_with_array(self): 1509 basepath = self.tempdir.name 1510 # use custom ArrayWritable types and converters to handle arrays 1511 array_data = [(1, array('d')), 1512 (1, array('d', [1.0, 2.0, 3.0])), 1513 (2, array('d', [3.0, 4.0, 5.0]))] 1514 self.sc.parallelize(array_data).saveAsNewAPIHadoopFile( 1515 basepath + "/newhadoop/", 1516 "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", 1517 "org.apache.hadoop.io.IntWritable", 1518 "org.apache.spark.api.python.DoubleArrayWritable", 1519 valueConverter="org.apache.spark.api.python.DoubleArrayToWritableConverter") 1520 result = sorted(self.sc.newAPIHadoopFile( 1521 basepath + "/newhadoop/", 1522 "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", 1523 "org.apache.hadoop.io.IntWritable", 1524 "org.apache.spark.api.python.DoubleArrayWritable", 1525 valueConverter="org.apache.spark.api.python.WritableToDoubleArrayConverter").collect()) 1526 self.assertEqual(result, array_data) 1527 1528 conf = { 1529 "mapreduce.outputformat.class": 1530 "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", 1531 "mapred.output.key.class": "org.apache.hadoop.io.IntWritable", 1532 "mapred.output.value.class": "org.apache.spark.api.python.DoubleArrayWritable", 1533 "mapred.output.dir": basepath + "/newdataset/" 1534 } 1535 self.sc.parallelize(array_data).saveAsNewAPIHadoopDataset( 1536 conf, 1537 valueConverter="org.apache.spark.api.python.DoubleArrayToWritableConverter") 1538 input_conf = {"mapred.input.dir": basepath + "/newdataset/"} 1539 new_dataset = sorted(self.sc.newAPIHadoopRDD( 1540 "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", 1541 "org.apache.hadoop.io.IntWritable", 1542 "org.apache.spark.api.python.DoubleArrayWritable", 1543 valueConverter="org.apache.spark.api.python.WritableToDoubleArrayConverter", 1544 conf=input_conf).collect()) 1545 self.assertEqual(new_dataset, array_data) 1546 1547 def test_newolderror(self): 1548 basepath = self.tempdir.name 1549 rdd = self.sc.parallelize(range(1, 4)).map(lambda x: (x, "a" * x)) 1550 self.assertRaises(Exception, lambda: rdd.saveAsHadoopFile( 1551 basepath + "/newolderror/saveAsHadoopFile/", 1552 "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat")) 1553 self.assertRaises(Exception, lambda: rdd.saveAsNewAPIHadoopFile( 1554 basepath + "/newolderror/saveAsNewAPIHadoopFile/", 1555 "org.apache.hadoop.mapred.SequenceFileOutputFormat")) 1556 1557 def test_bad_inputs(self): 1558 basepath = self.tempdir.name 1559 rdd = self.sc.parallelize(range(1, 4)).map(lambda x: (x, "a" * x)) 1560 self.assertRaises(Exception, lambda: rdd.saveAsHadoopFile( 1561 basepath + "/badinputs/saveAsHadoopFile/", 1562 "org.apache.hadoop.mapred.NotValidOutputFormat")) 1563 self.assertRaises(Exception, lambda: rdd.saveAsNewAPIHadoopFile( 1564 basepath + "/badinputs/saveAsNewAPIHadoopFile/", 1565 "org.apache.hadoop.mapreduce.lib.output.NotValidOutputFormat")) 1566 1567 def test_converters(self): 1568 # use of custom converters 1569 basepath = self.tempdir.name 1570 data = [(1, {3.0: u'bb'}), 1571 (2, {1.0: u'aa'}), 1572 (3, {2.0: u'dd'})] 1573 self.sc.parallelize(data).saveAsNewAPIHadoopFile( 1574 basepath + "/converters/", 1575 "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", 1576 keyConverter="org.apache.spark.api.python.TestOutputKeyConverter", 1577 valueConverter="org.apache.spark.api.python.TestOutputValueConverter") 1578 converted = sorted(self.sc.sequenceFile(basepath + "/converters/").collect()) 1579 expected = [(u'1', 3.0), 1580 (u'2', 1.0), 1581 (u'3', 2.0)] 1582 self.assertEqual(converted, expected) 1583 1584 def test_reserialization(self): 1585 basepath = self.tempdir.name 1586 x = range(1, 5) 1587 y = range(1001, 1005) 1588 data = list(zip(x, y)) 1589 rdd = self.sc.parallelize(x).zip(self.sc.parallelize(y)) 1590 rdd.saveAsSequenceFile(basepath + "/reserialize/sequence") 1591 result1 = sorted(self.sc.sequenceFile(basepath + "/reserialize/sequence").collect()) 1592 self.assertEqual(result1, data) 1593 1594 rdd.saveAsHadoopFile( 1595 basepath + "/reserialize/hadoop", 1596 "org.apache.hadoop.mapred.SequenceFileOutputFormat") 1597 result2 = sorted(self.sc.sequenceFile(basepath + "/reserialize/hadoop").collect()) 1598 self.assertEqual(result2, data) 1599 1600 rdd.saveAsNewAPIHadoopFile( 1601 basepath + "/reserialize/newhadoop", 1602 "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat") 1603 result3 = sorted(self.sc.sequenceFile(basepath + "/reserialize/newhadoop").collect()) 1604 self.assertEqual(result3, data) 1605 1606 conf4 = { 1607 "mapred.output.format.class": "org.apache.hadoop.mapred.SequenceFileOutputFormat", 1608 "mapred.output.key.class": "org.apache.hadoop.io.IntWritable", 1609 "mapred.output.value.class": "org.apache.hadoop.io.IntWritable", 1610 "mapred.output.dir": basepath + "/reserialize/dataset"} 1611 rdd.saveAsHadoopDataset(conf4) 1612 result4 = sorted(self.sc.sequenceFile(basepath + "/reserialize/dataset").collect()) 1613 self.assertEqual(result4, data) 1614 1615 conf5 = {"mapreduce.outputformat.class": 1616 "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", 1617 "mapred.output.key.class": "org.apache.hadoop.io.IntWritable", 1618 "mapred.output.value.class": "org.apache.hadoop.io.IntWritable", 1619 "mapred.output.dir": basepath + "/reserialize/newdataset"} 1620 rdd.saveAsNewAPIHadoopDataset(conf5) 1621 result5 = sorted(self.sc.sequenceFile(basepath + "/reserialize/newdataset").collect()) 1622 self.assertEqual(result5, data) 1623 1624 def test_malformed_RDD(self): 1625 basepath = self.tempdir.name 1626 # non-batch-serialized RDD[[(K, V)]] should be rejected 1627 data = [[(1, "a")], [(2, "aa")], [(3, "aaa")]] 1628 rdd = self.sc.parallelize(data, len(data)) 1629 self.assertRaises(Exception, lambda: rdd.saveAsSequenceFile( 1630 basepath + "/malformed/sequence")) 1631 1632 1633class DaemonTests(unittest.TestCase): 1634 def connect(self, port): 1635 from socket import socket, AF_INET, SOCK_STREAM 1636 sock = socket(AF_INET, SOCK_STREAM) 1637 sock.connect(('127.0.0.1', port)) 1638 # send a split index of -1 to shutdown the worker 1639 sock.send(b"\xFF\xFF\xFF\xFF") 1640 sock.close() 1641 return True 1642 1643 def do_termination_test(self, terminator): 1644 from subprocess import Popen, PIPE 1645 from errno import ECONNREFUSED 1646 1647 # start daemon 1648 daemon_path = os.path.join(os.path.dirname(__file__), "daemon.py") 1649 python_exec = sys.executable or os.environ.get("PYSPARK_PYTHON") 1650 daemon = Popen([python_exec, daemon_path], stdin=PIPE, stdout=PIPE) 1651 1652 # read the port number 1653 port = read_int(daemon.stdout) 1654 1655 # daemon should accept connections 1656 self.assertTrue(self.connect(port)) 1657 1658 # request shutdown 1659 terminator(daemon) 1660 time.sleep(1) 1661 1662 # daemon should no longer accept connections 1663 try: 1664 self.connect(port) 1665 except EnvironmentError as exception: 1666 self.assertEqual(exception.errno, ECONNREFUSED) 1667 else: 1668 self.fail("Expected EnvironmentError to be raised") 1669 1670 def test_termination_stdin(self): 1671 """Ensure that daemon and workers terminate when stdin is closed.""" 1672 self.do_termination_test(lambda daemon: daemon.stdin.close()) 1673 1674 def test_termination_sigterm(self): 1675 """Ensure that daemon and workers terminate on SIGTERM.""" 1676 from signal import SIGTERM 1677 self.do_termination_test(lambda daemon: os.kill(daemon.pid, SIGTERM)) 1678 1679 1680class WorkerTests(ReusedPySparkTestCase): 1681 def test_cancel_task(self): 1682 temp = tempfile.NamedTemporaryFile(delete=True) 1683 temp.close() 1684 path = temp.name 1685 1686 def sleep(x): 1687 import os 1688 import time 1689 with open(path, 'w') as f: 1690 f.write("%d %d" % (os.getppid(), os.getpid())) 1691 time.sleep(100) 1692 1693 # start job in background thread 1694 def run(): 1695 try: 1696 self.sc.parallelize(range(1), 1).foreach(sleep) 1697 except Exception: 1698 pass 1699 import threading 1700 t = threading.Thread(target=run) 1701 t.daemon = True 1702 t.start() 1703 1704 daemon_pid, worker_pid = 0, 0 1705 while True: 1706 if os.path.exists(path): 1707 with open(path) as f: 1708 data = f.read().split(' ') 1709 daemon_pid, worker_pid = map(int, data) 1710 break 1711 time.sleep(0.1) 1712 1713 # cancel jobs 1714 self.sc.cancelAllJobs() 1715 t.join() 1716 1717 for i in range(50): 1718 try: 1719 os.kill(worker_pid, 0) 1720 time.sleep(0.1) 1721 except OSError: 1722 break # worker was killed 1723 else: 1724 self.fail("worker has not been killed after 5 seconds") 1725 1726 try: 1727 os.kill(daemon_pid, 0) 1728 except OSError: 1729 self.fail("daemon had been killed") 1730 1731 # run a normal job 1732 rdd = self.sc.parallelize(xrange(100), 1) 1733 self.assertEqual(100, rdd.map(str).count()) 1734 1735 def test_after_exception(self): 1736 def raise_exception(_): 1737 raise Exception() 1738 rdd = self.sc.parallelize(xrange(100), 1) 1739 with QuietTest(self.sc): 1740 self.assertRaises(Exception, lambda: rdd.foreach(raise_exception)) 1741 self.assertEqual(100, rdd.map(str).count()) 1742 1743 def test_after_jvm_exception(self): 1744 tempFile = tempfile.NamedTemporaryFile(delete=False) 1745 tempFile.write(b"Hello World!") 1746 tempFile.close() 1747 data = self.sc.textFile(tempFile.name, 1) 1748 filtered_data = data.filter(lambda x: True) 1749 self.assertEqual(1, filtered_data.count()) 1750 os.unlink(tempFile.name) 1751 with QuietTest(self.sc): 1752 self.assertRaises(Exception, lambda: filtered_data.count()) 1753 1754 rdd = self.sc.parallelize(xrange(100), 1) 1755 self.assertEqual(100, rdd.map(str).count()) 1756 1757 def test_accumulator_when_reuse_worker(self): 1758 from pyspark.accumulators import INT_ACCUMULATOR_PARAM 1759 acc1 = self.sc.accumulator(0, INT_ACCUMULATOR_PARAM) 1760 self.sc.parallelize(xrange(100), 20).foreach(lambda x: acc1.add(x)) 1761 self.assertEqual(sum(range(100)), acc1.value) 1762 1763 acc2 = self.sc.accumulator(0, INT_ACCUMULATOR_PARAM) 1764 self.sc.parallelize(xrange(100), 20).foreach(lambda x: acc2.add(x)) 1765 self.assertEqual(sum(range(100)), acc2.value) 1766 self.assertEqual(sum(range(100)), acc1.value) 1767 1768 def test_reuse_worker_after_take(self): 1769 rdd = self.sc.parallelize(xrange(100000), 1) 1770 self.assertEqual(0, rdd.first()) 1771 1772 def count(): 1773 try: 1774 rdd.count() 1775 except Exception: 1776 pass 1777 1778 t = threading.Thread(target=count) 1779 t.daemon = True 1780 t.start() 1781 t.join(5) 1782 self.assertTrue(not t.isAlive()) 1783 self.assertEqual(100000, rdd.count()) 1784 1785 def test_with_different_versions_of_python(self): 1786 rdd = self.sc.parallelize(range(10)) 1787 rdd.count() 1788 version = self.sc.pythonVer 1789 self.sc.pythonVer = "2.0" 1790 try: 1791 with QuietTest(self.sc): 1792 self.assertRaises(Py4JJavaError, lambda: rdd.count()) 1793 finally: 1794 self.sc.pythonVer = version 1795 1796 1797class SparkSubmitTests(unittest.TestCase): 1798 1799 def setUp(self): 1800 self.programDir = tempfile.mkdtemp() 1801 self.sparkSubmit = os.path.join(os.environ.get("SPARK_HOME"), "bin", "spark-submit") 1802 1803 def tearDown(self): 1804 shutil.rmtree(self.programDir) 1805 1806 def createTempFile(self, name, content, dir=None): 1807 """ 1808 Create a temp file with the given name and content and return its path. 1809 Strips leading spaces from content up to the first '|' in each line. 1810 """ 1811 pattern = re.compile(r'^ *\|', re.MULTILINE) 1812 content = re.sub(pattern, '', content.strip()) 1813 if dir is None: 1814 path = os.path.join(self.programDir, name) 1815 else: 1816 os.makedirs(os.path.join(self.programDir, dir)) 1817 path = os.path.join(self.programDir, dir, name) 1818 with open(path, "w") as f: 1819 f.write(content) 1820 return path 1821 1822 def createFileInZip(self, name, content, ext=".zip", dir=None, zip_name=None): 1823 """ 1824 Create a zip archive containing a file with the given content and return its path. 1825 Strips leading spaces from content up to the first '|' in each line. 1826 """ 1827 pattern = re.compile(r'^ *\|', re.MULTILINE) 1828 content = re.sub(pattern, '', content.strip()) 1829 if dir is None: 1830 path = os.path.join(self.programDir, name + ext) 1831 else: 1832 path = os.path.join(self.programDir, dir, zip_name + ext) 1833 zip = zipfile.ZipFile(path, 'w') 1834 zip.writestr(name, content) 1835 zip.close() 1836 return path 1837 1838 def create_spark_package(self, artifact_name): 1839 group_id, artifact_id, version = artifact_name.split(":") 1840 self.createTempFile("%s-%s.pom" % (artifact_id, version), (""" 1841 |<?xml version="1.0" encoding="UTF-8"?> 1842 |<project xmlns="http://maven.apache.org/POM/4.0.0" 1843 | xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" 1844 | xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 1845 | http://maven.apache.org/xsd/maven-4.0.0.xsd"> 1846 | <modelVersion>4.0.0</modelVersion> 1847 | <groupId>%s</groupId> 1848 | <artifactId>%s</artifactId> 1849 | <version>%s</version> 1850 |</project> 1851 """ % (group_id, artifact_id, version)).lstrip(), 1852 os.path.join(group_id, artifact_id, version)) 1853 self.createFileInZip("%s.py" % artifact_id, """ 1854 |def myfunc(x): 1855 | return x + 1 1856 """, ".jar", os.path.join(group_id, artifact_id, version), 1857 "%s-%s" % (artifact_id, version)) 1858 1859 def test_single_script(self): 1860 """Submit and test a single script file""" 1861 script = self.createTempFile("test.py", """ 1862 |from pyspark import SparkContext 1863 | 1864 |sc = SparkContext() 1865 |print(sc.parallelize([1, 2, 3]).map(lambda x: x * 2).collect()) 1866 """) 1867 proc = subprocess.Popen([self.sparkSubmit, script], stdout=subprocess.PIPE) 1868 out, err = proc.communicate() 1869 self.assertEqual(0, proc.returncode) 1870 self.assertIn("[2, 4, 6]", out.decode('utf-8')) 1871 1872 def test_script_with_local_functions(self): 1873 """Submit and test a single script file calling a global function""" 1874 script = self.createTempFile("test.py", """ 1875 |from pyspark import SparkContext 1876 | 1877 |def foo(x): 1878 | return x * 3 1879 | 1880 |sc = SparkContext() 1881 |print(sc.parallelize([1, 2, 3]).map(foo).collect()) 1882 """) 1883 proc = subprocess.Popen([self.sparkSubmit, script], stdout=subprocess.PIPE) 1884 out, err = proc.communicate() 1885 self.assertEqual(0, proc.returncode) 1886 self.assertIn("[3, 6, 9]", out.decode('utf-8')) 1887 1888 def test_module_dependency(self): 1889 """Submit and test a script with a dependency on another module""" 1890 script = self.createTempFile("test.py", """ 1891 |from pyspark import SparkContext 1892 |from mylib import myfunc 1893 | 1894 |sc = SparkContext() 1895 |print(sc.parallelize([1, 2, 3]).map(myfunc).collect()) 1896 """) 1897 zip = self.createFileInZip("mylib.py", """ 1898 |def myfunc(x): 1899 | return x + 1 1900 """) 1901 proc = subprocess.Popen([self.sparkSubmit, "--py-files", zip, script], 1902 stdout=subprocess.PIPE) 1903 out, err = proc.communicate() 1904 self.assertEqual(0, proc.returncode) 1905 self.assertIn("[2, 3, 4]", out.decode('utf-8')) 1906 1907 def test_module_dependency_on_cluster(self): 1908 """Submit and test a script with a dependency on another module on a cluster""" 1909 script = self.createTempFile("test.py", """ 1910 |from pyspark import SparkContext 1911 |from mylib import myfunc 1912 | 1913 |sc = SparkContext() 1914 |print(sc.parallelize([1, 2, 3]).map(myfunc).collect()) 1915 """) 1916 zip = self.createFileInZip("mylib.py", """ 1917 |def myfunc(x): 1918 | return x + 1 1919 """) 1920 proc = subprocess.Popen([self.sparkSubmit, "--py-files", zip, "--master", 1921 "local-cluster[1,1,1024]", script], 1922 stdout=subprocess.PIPE) 1923 out, err = proc.communicate() 1924 self.assertEqual(0, proc.returncode) 1925 self.assertIn("[2, 3, 4]", out.decode('utf-8')) 1926 1927 def test_package_dependency(self): 1928 """Submit and test a script with a dependency on a Spark Package""" 1929 script = self.createTempFile("test.py", """ 1930 |from pyspark import SparkContext 1931 |from mylib import myfunc 1932 | 1933 |sc = SparkContext() 1934 |print(sc.parallelize([1, 2, 3]).map(myfunc).collect()) 1935 """) 1936 self.create_spark_package("a:mylib:0.1") 1937 proc = subprocess.Popen([self.sparkSubmit, "--packages", "a:mylib:0.1", "--repositories", 1938 "file:" + self.programDir, script], stdout=subprocess.PIPE) 1939 out, err = proc.communicate() 1940 self.assertEqual(0, proc.returncode) 1941 self.assertIn("[2, 3, 4]", out.decode('utf-8')) 1942 1943 def test_package_dependency_on_cluster(self): 1944 """Submit and test a script with a dependency on a Spark Package on a cluster""" 1945 script = self.createTempFile("test.py", """ 1946 |from pyspark import SparkContext 1947 |from mylib import myfunc 1948 | 1949 |sc = SparkContext() 1950 |print(sc.parallelize([1, 2, 3]).map(myfunc).collect()) 1951 """) 1952 self.create_spark_package("a:mylib:0.1") 1953 proc = subprocess.Popen([self.sparkSubmit, "--packages", "a:mylib:0.1", "--repositories", 1954 "file:" + self.programDir, "--master", 1955 "local-cluster[1,1,1024]", script], stdout=subprocess.PIPE) 1956 out, err = proc.communicate() 1957 self.assertEqual(0, proc.returncode) 1958 self.assertIn("[2, 3, 4]", out.decode('utf-8')) 1959 1960 def test_single_script_on_cluster(self): 1961 """Submit and test a single script on a cluster""" 1962 script = self.createTempFile("test.py", """ 1963 |from pyspark import SparkContext 1964 | 1965 |def foo(x): 1966 | return x * 2 1967 | 1968 |sc = SparkContext() 1969 |print(sc.parallelize([1, 2, 3]).map(foo).collect()) 1970 """) 1971 # this will fail if you have different spark.executor.memory 1972 # in conf/spark-defaults.conf 1973 proc = subprocess.Popen( 1974 [self.sparkSubmit, "--master", "local-cluster[1,1,1024]", script], 1975 stdout=subprocess.PIPE) 1976 out, err = proc.communicate() 1977 self.assertEqual(0, proc.returncode) 1978 self.assertIn("[2, 4, 6]", out.decode('utf-8')) 1979 1980 def test_user_configuration(self): 1981 """Make sure user configuration is respected (SPARK-19307)""" 1982 script = self.createTempFile("test.py", """ 1983 |from pyspark import SparkConf, SparkContext 1984 | 1985 |conf = SparkConf().set("spark.test_config", "1") 1986 |sc = SparkContext(conf = conf) 1987 |try: 1988 | if sc._conf.get("spark.test_config") != "1": 1989 | raise Exception("Cannot find spark.test_config in SparkContext's conf.") 1990 |finally: 1991 | sc.stop() 1992 """) 1993 proc = subprocess.Popen( 1994 [self.sparkSubmit, "--master", "local", script], 1995 stdout=subprocess.PIPE, 1996 stderr=subprocess.STDOUT) 1997 out, err = proc.communicate() 1998 self.assertEqual(0, proc.returncode, msg="Process failed with error:\n {0}".format(out)) 1999 2000 2001class ContextTests(unittest.TestCase): 2002 2003 def test_failed_sparkcontext_creation(self): 2004 # Regression test for SPARK-1550 2005 self.assertRaises(Exception, lambda: SparkContext("an-invalid-master-name")) 2006 2007 def test_get_or_create(self): 2008 with SparkContext.getOrCreate() as sc: 2009 self.assertTrue(SparkContext.getOrCreate() is sc) 2010 2011 def test_parallelize_eager_cleanup(self): 2012 with SparkContext() as sc: 2013 temp_files = os.listdir(sc._temp_dir) 2014 rdd = sc.parallelize([0, 1, 2]) 2015 post_parallalize_temp_files = os.listdir(sc._temp_dir) 2016 self.assertEqual(temp_files, post_parallalize_temp_files) 2017 2018 def test_set_conf(self): 2019 # This is for an internal use case. When there is an existing SparkContext, 2020 # SparkSession's builder needs to set configs into SparkContext's conf. 2021 sc = SparkContext() 2022 sc._conf.set("spark.test.SPARK16224", "SPARK16224") 2023 self.assertEqual(sc._jsc.sc().conf().get("spark.test.SPARK16224"), "SPARK16224") 2024 sc.stop() 2025 2026 def test_stop(self): 2027 sc = SparkContext() 2028 self.assertNotEqual(SparkContext._active_spark_context, None) 2029 sc.stop() 2030 self.assertEqual(SparkContext._active_spark_context, None) 2031 2032 def test_with(self): 2033 with SparkContext() as sc: 2034 self.assertNotEqual(SparkContext._active_spark_context, None) 2035 self.assertEqual(SparkContext._active_spark_context, None) 2036 2037 def test_with_exception(self): 2038 try: 2039 with SparkContext() as sc: 2040 self.assertNotEqual(SparkContext._active_spark_context, None) 2041 raise Exception() 2042 except: 2043 pass 2044 self.assertEqual(SparkContext._active_spark_context, None) 2045 2046 def test_with_stop(self): 2047 with SparkContext() as sc: 2048 self.assertNotEqual(SparkContext._active_spark_context, None) 2049 sc.stop() 2050 self.assertEqual(SparkContext._active_spark_context, None) 2051 2052 def test_progress_api(self): 2053 with SparkContext() as sc: 2054 sc.setJobGroup('test_progress_api', '', True) 2055 rdd = sc.parallelize(range(10)).map(lambda x: time.sleep(100)) 2056 2057 def run(): 2058 try: 2059 rdd.count() 2060 except Exception: 2061 pass 2062 t = threading.Thread(target=run) 2063 t.daemon = True 2064 t.start() 2065 # wait for scheduler to start 2066 time.sleep(1) 2067 2068 tracker = sc.statusTracker() 2069 jobIds = tracker.getJobIdsForGroup('test_progress_api') 2070 self.assertEqual(1, len(jobIds)) 2071 job = tracker.getJobInfo(jobIds[0]) 2072 self.assertEqual(1, len(job.stageIds)) 2073 stage = tracker.getStageInfo(job.stageIds[0]) 2074 self.assertEqual(rdd.getNumPartitions(), stage.numTasks) 2075 2076 sc.cancelAllJobs() 2077 t.join() 2078 # wait for event listener to update the status 2079 time.sleep(1) 2080 2081 job = tracker.getJobInfo(jobIds[0]) 2082 self.assertEqual('FAILED', job.status) 2083 self.assertEqual([], tracker.getActiveJobsIds()) 2084 self.assertEqual([], tracker.getActiveStageIds()) 2085 2086 sc.stop() 2087 2088 def test_startTime(self): 2089 with SparkContext() as sc: 2090 self.assertGreater(sc.startTime, 0) 2091 2092 2093class ConfTests(unittest.TestCase): 2094 def test_memory_conf(self): 2095 memoryList = ["1T", "1G", "1M", "1024K"] 2096 for memory in memoryList: 2097 sc = SparkContext(conf=SparkConf().set("spark.python.worker.memory", memory)) 2098 l = list(range(1024)) 2099 random.shuffle(l) 2100 rdd = sc.parallelize(l, 4) 2101 self.assertEqual(sorted(l), rdd.sortBy(lambda x: x).collect()) 2102 sc.stop() 2103 2104 2105class KeywordOnlyTests(unittest.TestCase): 2106 class Wrapped(object): 2107 @keyword_only 2108 def set(self, x=None, y=None): 2109 if "x" in self._input_kwargs: 2110 self._x = self._input_kwargs["x"] 2111 if "y" in self._input_kwargs: 2112 self._y = self._input_kwargs["y"] 2113 return x, y 2114 2115 def test_keywords(self): 2116 w = self.Wrapped() 2117 x, y = w.set(y=1) 2118 self.assertEqual(y, 1) 2119 self.assertEqual(y, w._y) 2120 self.assertIsNone(x) 2121 self.assertFalse(hasattr(w, "_x")) 2122 2123 def test_non_keywords(self): 2124 w = self.Wrapped() 2125 self.assertRaises(TypeError, lambda: w.set(0, y=1)) 2126 2127 def test_kwarg_ownership(self): 2128 # test _input_kwargs is owned by each class instance and not a shared static variable 2129 class Setter(object): 2130 @keyword_only 2131 def set(self, x=None, other=None, other_x=None): 2132 if "other" in self._input_kwargs: 2133 self._input_kwargs["other"].set(x=self._input_kwargs["other_x"]) 2134 self._x = self._input_kwargs["x"] 2135 2136 a = Setter() 2137 b = Setter() 2138 a.set(x=1, other=b, other_x=2) 2139 self.assertEqual(a._x, 1) 2140 self.assertEqual(b._x, 2) 2141 2142 2143@unittest.skipIf(not _have_scipy, "SciPy not installed") 2144class SciPyTests(PySparkTestCase): 2145 2146 """General PySpark tests that depend on scipy """ 2147 2148 def test_serialize(self): 2149 from scipy.special import gammaln 2150 x = range(1, 5) 2151 expected = list(map(gammaln, x)) 2152 observed = self.sc.parallelize(x).map(gammaln).collect() 2153 self.assertEqual(expected, observed) 2154 2155 2156@unittest.skipIf(not _have_numpy, "NumPy not installed") 2157class NumPyTests(PySparkTestCase): 2158 2159 """General PySpark tests that depend on numpy """ 2160 2161 def test_statcounter_array(self): 2162 x = self.sc.parallelize([np.array([1.0, 1.0]), np.array([2.0, 2.0]), np.array([3.0, 3.0])]) 2163 s = x.stats() 2164 self.assertSequenceEqual([2.0, 2.0], s.mean().tolist()) 2165 self.assertSequenceEqual([1.0, 1.0], s.min().tolist()) 2166 self.assertSequenceEqual([3.0, 3.0], s.max().tolist()) 2167 self.assertSequenceEqual([1.0, 1.0], s.sampleStdev().tolist()) 2168 2169 stats_dict = s.asDict() 2170 self.assertEqual(3, stats_dict['count']) 2171 self.assertSequenceEqual([2.0, 2.0], stats_dict['mean'].tolist()) 2172 self.assertSequenceEqual([1.0, 1.0], stats_dict['min'].tolist()) 2173 self.assertSequenceEqual([3.0, 3.0], stats_dict['max'].tolist()) 2174 self.assertSequenceEqual([6.0, 6.0], stats_dict['sum'].tolist()) 2175 self.assertSequenceEqual([1.0, 1.0], stats_dict['stdev'].tolist()) 2176 self.assertSequenceEqual([1.0, 1.0], stats_dict['variance'].tolist()) 2177 2178 stats_sample_dict = s.asDict(sample=True) 2179 self.assertEqual(3, stats_dict['count']) 2180 self.assertSequenceEqual([2.0, 2.0], stats_sample_dict['mean'].tolist()) 2181 self.assertSequenceEqual([1.0, 1.0], stats_sample_dict['min'].tolist()) 2182 self.assertSequenceEqual([3.0, 3.0], stats_sample_dict['max'].tolist()) 2183 self.assertSequenceEqual([6.0, 6.0], stats_sample_dict['sum'].tolist()) 2184 self.assertSequenceEqual( 2185 [0.816496580927726, 0.816496580927726], stats_sample_dict['stdev'].tolist()) 2186 self.assertSequenceEqual( 2187 [0.6666666666666666, 0.6666666666666666], stats_sample_dict['variance'].tolist()) 2188 2189 2190if __name__ == "__main__": 2191 from pyspark.tests import * 2192 if not _have_scipy: 2193 print("NOTE: Skipping SciPy tests as it does not seem to be installed") 2194 if not _have_numpy: 2195 print("NOTE: Skipping NumPy tests as it does not seem to be installed") 2196 if xmlrunner: 2197 unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports')) 2198 else: 2199 unittest.main() 2200 if not _have_scipy: 2201 print("NOTE: SciPy tests were skipped as it does not seem to be installed") 2202 if not _have_numpy: 2203 print("NOTE: NumPy tests were skipped as it does not seem to be installed") 2204