1""" 2Tests specific to the collections module. 3""" 4from __future__ import absolute_import, division, print_function 5 6import io 7 8import numpy as np 9from numpy.testing import ( 10 assert_array_equal, assert_array_almost_equal, assert_equal) 11import pytest 12 13import matplotlib.pyplot as plt 14import matplotlib.collections as mcollections 15import matplotlib.transforms as mtransforms 16from matplotlib.collections import Collection, LineCollection, EventCollection 17from matplotlib.testing.decorators import image_comparison 18 19 20def generate_EventCollection_plot(): 21 ''' 22 generate the initial collection and plot it 23 ''' 24 positions = np.array([0., 1., 2., 3., 5., 8., 13., 21.]) 25 extra_positions = np.array([34., 55., 89.]) 26 orientation = 'horizontal' 27 lineoffset = 1 28 linelength = .5 29 linewidth = 2 30 color = [1, 0, 0, 1] 31 linestyle = 'solid' 32 antialiased = True 33 34 coll = EventCollection(positions, 35 orientation=orientation, 36 lineoffset=lineoffset, 37 linelength=linelength, 38 linewidth=linewidth, 39 color=color, 40 linestyle=linestyle, 41 antialiased=antialiased 42 ) 43 44 fig = plt.figure() 45 splt = fig.add_subplot(1, 1, 1) 46 splt.add_collection(coll) 47 splt.set_title('EventCollection: default') 48 props = {'positions': positions, 49 'extra_positions': extra_positions, 50 'orientation': orientation, 51 'lineoffset': lineoffset, 52 'linelength': linelength, 53 'linewidth': linewidth, 54 'color': color, 55 'linestyle': linestyle, 56 'antialiased': antialiased 57 } 58 splt.set_xlim(-1, 22) 59 splt.set_ylim(0, 2) 60 return splt, coll, props 61 62 63@image_comparison(baseline_images=['EventCollection_plot__default']) 64def test__EventCollection__get_segments(): 65 ''' 66 check to make sure the default segments have the correct coordinates 67 ''' 68 _, coll, props = generate_EventCollection_plot() 69 check_segments(coll, 70 props['positions'], 71 props['linelength'], 72 props['lineoffset'], 73 props['orientation']) 74 75 76def 
test__EventCollection__get_positions(): 77 ''' 78 check to make sure the default positions match the input positions 79 ''' 80 _, coll, props = generate_EventCollection_plot() 81 np.testing.assert_array_equal(props['positions'], coll.get_positions()) 82 83 84def test__EventCollection__get_orientation(): 85 ''' 86 check to make sure the default orientation matches the input 87 orientation 88 ''' 89 _, coll, props = generate_EventCollection_plot() 90 assert_equal(props['orientation'], coll.get_orientation()) 91 92 93def test__EventCollection__is_horizontal(): 94 ''' 95 check to make sure the default orientation matches the input 96 orientation 97 ''' 98 _, coll, _ = generate_EventCollection_plot() 99 assert_equal(True, coll.is_horizontal()) 100 101 102def test__EventCollection__get_linelength(): 103 ''' 104 check to make sure the default linelength matches the input linelength 105 ''' 106 _, coll, props = generate_EventCollection_plot() 107 assert_equal(props['linelength'], coll.get_linelength()) 108 109 110def test__EventCollection__get_lineoffset(): 111 ''' 112 check to make sure the default lineoffset matches the input lineoffset 113 ''' 114 _, coll, props = generate_EventCollection_plot() 115 assert_equal(props['lineoffset'], coll.get_lineoffset()) 116 117 118def test__EventCollection__get_linestyle(): 119 ''' 120 check to make sure the default linestyle matches the input linestyle 121 ''' 122 _, coll, _ = generate_EventCollection_plot() 123 assert_equal(coll.get_linestyle(), [(None, None)]) 124 125 126def test__EventCollection__get_color(): 127 ''' 128 check to make sure the default color matches the input color 129 ''' 130 _, coll, props = generate_EventCollection_plot() 131 np.testing.assert_array_equal(props['color'], coll.get_color()) 132 check_allprop_array(coll.get_colors(), props['color']) 133 134 135@image_comparison(baseline_images=['EventCollection_plot__set_positions']) 136def test__EventCollection__set_positions(): 137 ''' 138 check to make sure 
set_positions works properly 139 ''' 140 splt, coll, props = generate_EventCollection_plot() 141 new_positions = np.hstack([props['positions'], props['extra_positions']]) 142 coll.set_positions(new_positions) 143 np.testing.assert_array_equal(new_positions, coll.get_positions()) 144 check_segments(coll, new_positions, 145 props['linelength'], 146 props['lineoffset'], 147 props['orientation']) 148 splt.set_title('EventCollection: set_positions') 149 splt.set_xlim(-1, 90) 150 151 152@image_comparison(baseline_images=['EventCollection_plot__add_positions']) 153def test__EventCollection__add_positions(): 154 ''' 155 check to make sure add_positions works properly 156 ''' 157 splt, coll, props = generate_EventCollection_plot() 158 new_positions = np.hstack([props['positions'], 159 props['extra_positions'][0]]) 160 coll.add_positions(props['extra_positions'][0]) 161 np.testing.assert_array_equal(new_positions, coll.get_positions()) 162 check_segments(coll, 163 new_positions, 164 props['linelength'], 165 props['lineoffset'], 166 props['orientation']) 167 splt.set_title('EventCollection: add_positions') 168 splt.set_xlim(-1, 35) 169 170 171@image_comparison(baseline_images=['EventCollection_plot__append_positions']) 172def test__EventCollection__append_positions(): 173 ''' 174 check to make sure append_positions works properly 175 ''' 176 splt, coll, props = generate_EventCollection_plot() 177 new_positions = np.hstack([props['positions'], 178 props['extra_positions'][2]]) 179 coll.append_positions(props['extra_positions'][2]) 180 np.testing.assert_array_equal(new_positions, coll.get_positions()) 181 check_segments(coll, 182 new_positions, 183 props['linelength'], 184 props['lineoffset'], 185 props['orientation']) 186 splt.set_title('EventCollection: append_positions') 187 splt.set_xlim(-1, 90) 188 189 190@image_comparison(baseline_images=['EventCollection_plot__extend_positions']) 191def test__EventCollection__extend_positions(): 192 ''' 193 check to make sure 
extend_positions works properly 194 ''' 195 splt, coll, props = generate_EventCollection_plot() 196 new_positions = np.hstack([props['positions'], 197 props['extra_positions'][1:]]) 198 coll.extend_positions(props['extra_positions'][1:]) 199 np.testing.assert_array_equal(new_positions, coll.get_positions()) 200 check_segments(coll, 201 new_positions, 202 props['linelength'], 203 props['lineoffset'], 204 props['orientation']) 205 splt.set_title('EventCollection: extend_positions') 206 splt.set_xlim(-1, 90) 207 208 209@image_comparison(baseline_images=['EventCollection_plot__switch_orientation']) 210def test__EventCollection__switch_orientation(): 211 ''' 212 check to make sure switch_orientation works properly 213 ''' 214 splt, coll, props = generate_EventCollection_plot() 215 new_orientation = 'vertical' 216 coll.switch_orientation() 217 assert_equal(new_orientation, coll.get_orientation()) 218 assert_equal(False, coll.is_horizontal()) 219 new_positions = coll.get_positions() 220 check_segments(coll, 221 new_positions, 222 props['linelength'], 223 props['lineoffset'], new_orientation) 224 splt.set_title('EventCollection: switch_orientation') 225 splt.set_ylim(-1, 22) 226 splt.set_xlim(0, 2) 227 228 229@image_comparison( 230 baseline_images=['EventCollection_plot__switch_orientation__2x']) 231def test__EventCollection__switch_orientation_2x(): 232 ''' 233 check to make sure calling switch_orientation twice sets the 234 orientation back to the default 235 ''' 236 splt, coll, props = generate_EventCollection_plot() 237 coll.switch_orientation() 238 coll.switch_orientation() 239 new_positions = coll.get_positions() 240 assert_equal(props['orientation'], coll.get_orientation()) 241 assert_equal(True, coll.is_horizontal()) 242 np.testing.assert_array_equal(props['positions'], new_positions) 243 check_segments(coll, 244 new_positions, 245 props['linelength'], 246 props['lineoffset'], 247 props['orientation']) 248 splt.set_title('EventCollection: switch_orientation 2x') 
@image_comparison(baseline_images=['EventCollection_plot__set_orientation'])
def test__EventCollection__set_orientation():
    '''
    check to make sure set_orientation works properly
    '''
    splt, coll, props = generate_EventCollection_plot()
    new_orientation = 'vertical'
    coll.set_orientation(new_orientation)
    assert_equal(new_orientation, coll.get_orientation())
    assert_equal(False, coll.is_horizontal())
    check_segments(coll,
                   props['positions'],
                   props['linelength'],
                   props['lineoffset'],
                   new_orientation)
    splt.set_title('EventCollection: set_orientation')
    # the collection is now vertical, so swap the axis limits as well
    splt.set_ylim(-1, 22)
    splt.set_xlim(0, 2)


@image_comparison(baseline_images=['EventCollection_plot__set_linelength'])
def test__EventCollection__set_linelength():
    '''
    check to make sure set_linelength works properly
    '''
    splt, coll, props = generate_EventCollection_plot()
    new_linelength = 15
    coll.set_linelength(new_linelength)
    assert_equal(new_linelength, coll.get_linelength())
    check_segments(coll,
                   props['positions'],
                   new_linelength,
                   props['lineoffset'],
                   props['orientation'])
    splt.set_title('EventCollection: set_linelength')
    splt.set_ylim(-20, 20)


@image_comparison(baseline_images=['EventCollection_plot__set_lineoffset'])
def test__EventCollection__set_lineoffset():
    '''
    check to make sure set_lineoffset works properly
    '''
    splt, coll, props = generate_EventCollection_plot()
    new_lineoffset = -5.
    coll.set_lineoffset(new_lineoffset)
    assert_equal(new_lineoffset, coll.get_lineoffset())
    check_segments(coll,
                   props['positions'],
                   props['linelength'],
                   new_lineoffset,
                   props['orientation'])
    splt.set_title('EventCollection: set_lineoffset')
    splt.set_ylim(-6, -4)


@image_comparison(baseline_images=['EventCollection_plot__set_linestyle'])
def test__EventCollection__set_linestyle():
    '''
    check to make sure set_linestyle works properly
    '''
    splt, coll, _ = generate_EventCollection_plot()
    new_linestyle = 'dashed'
    coll.set_linestyle(new_linestyle)
    assert_equal(coll.get_linestyle(), [(0, (6.0, 6.0))])
    splt.set_title('EventCollection: set_linestyle')


@image_comparison(baseline_images=['EventCollection_plot__set_ls_dash'],
                  remove_text=True)
def test__EventCollection__set_linestyle_single_dash():
    '''
    check to make sure set_linestyle accepts a single dash pattern
    '''
    splt, coll, _ = generate_EventCollection_plot()
    new_linestyle = (0, (6., 6.))
    coll.set_linestyle(new_linestyle)
    assert_equal(coll.get_linestyle(), [(0, (6.0, 6.0))])
    splt.set_title('EventCollection: set_linestyle')


@image_comparison(baseline_images=['EventCollection_plot__set_linewidth'])
def test__EventCollection__set_linewidth():
    '''
    check to make sure set_linewidth works properly
    '''
    splt, coll, _ = generate_EventCollection_plot()
    new_linewidth = 5
    coll.set_linewidth(new_linewidth)
    assert_equal(coll.get_linewidth(), new_linewidth)
    splt.set_title('EventCollection: set_linewidth')


@image_comparison(baseline_images=['EventCollection_plot__set_color'])
def test__EventCollection__set_color():
    '''
    check to make sure set_color works properly
    '''
    splt, coll, _ = generate_EventCollection_plot()
    new_color = np.array([0, 1, 1, 1])
    coll.set_color(new_color)
    np.testing.assert_array_equal(new_color, coll.get_color())
    check_allprop_array(coll.get_colors(), new_color)
    splt.set_title('EventCollection: set_color')


def check_segments(coll, positions, linelength, lineoffset, orientation):
    '''
    check to make sure all values in the segment are correct, given a
    particular set of inputs

    coll is the collection whose ``get_segments()`` result is checked,
    positions are the expected event positions (indexed in segment
    order), linelength and lineoffset give the expected extent of each
    segment, and orientation is 'horizontal', 'vertical', 'none' or
    None (the last two are treated as horizontal)

    raises ValueError for any other orientation string

    note: this is not a test, it is used by tests
    '''
    segments = coll.get_segments()
    # None has to be handled before calling .lower(), otherwise the
    # None case raises AttributeError instead of being accepted
    if orientation is None:
        orientation_key = 'none'
    else:
        orientation_key = orientation.lower()

    if orientation_key in ('horizontal', 'none'):
        # horizontal: the line extent is checked along index 1 (y) and
        # the event position along index 0 (x)
        pos1 = 1
        pos2 = 0
    elif orientation_key == 'vertical':
        # vertical: the line extent is checked along index 0 (x) and
        # the event position along index 1 (y)
        pos1 = 0
        pos2 = 1
    else:
        raise ValueError("orientation must be 'horizontal' or 'vertical'")

    # test to make sure each segment is correct
    for i, segment in enumerate(segments):
        assert_equal(segment[0, pos1], lineoffset + linelength / 2.)
        assert_equal(segment[1, pos1], lineoffset - linelength / 2.)
        assert_equal(segment[0, pos2], positions[i])
        assert_equal(segment[1, pos2], positions[i])


def check_allprop_array(values, target):
    '''
    check to make sure all values match the given target if arrays

    note: this is not a test, it is used by tests
    '''
    for value in values:
        np.testing.assert_array_equal(value, target)


def test_null_collection_datalim():
    # An empty collection must report the null bounding box as its datalim.
    col = mcollections.PathCollection([])
    col_data_lim = col.get_datalim(mtransforms.IdentityTransform())
    assert_array_equal(col_data_lim.get_points(),
                       mtransforms.Bbox.null().get_points())


def test_add_collection():
    # Test if data limits are unchanged by adding an empty collection.
    # Github issue #1490, pull #1497.
    plt.figure()
    ax = plt.axes()
    coll = ax.scatter([0, 1], [0, 1])
    ax.add_collection(coll)
    bounds = ax.dataLim.bounds
    coll = ax.scatter([], [])
    assert_equal(ax.dataLim.bounds, bounds)


def test_quiver_limits():
    ax = plt.axes()
    x, y = np.arange(8), np.arange(10)
    u = v = np.linspace(0, 10, 80).reshape(10, 8)
    q = plt.quiver(x, y, u, v)
    assert_equal(q.get_datalim(ax.transData).bounds, (0., 0., 7., 9.))

    plt.figure()
    ax = plt.axes()
    x = np.linspace(-5, 10, 20)
    y = np.linspace(-2, 4, 10)
    y, x = np.meshgrid(y, x)
    trans = mtransforms.Affine2D().translate(25, 32) + ax.transData
    plt.quiver(x, y, np.sin(x), np.cos(y), transform=trans)
    assert_equal(ax.dataLim.bounds, (20.0, 30.0, 15.0, 6.0))


def test_barb_limits():
    ax = plt.axes()
    x = np.linspace(-5, 10, 20)
    y = np.linspace(-2, 4, 10)
    y, x = np.meshgrid(y, x)
    trans = mtransforms.Affine2D().translate(25, 32) + ax.transData
    plt.barbs(x, y, np.sin(x), np.cos(y), transform=trans)
    # The calculated bounds are approximately the bounds of the original data,
    # this is because the entire path is taken into account when updating the
    # datalim.
    assert_array_almost_equal(ax.dataLim.bounds, (20, 30, 15, 6),
                              decimal=1)


@image_comparison(baseline_images=['EllipseCollection_test_image'],
                  extensions=['png'],
                  remove_text=True)
def test_EllipseCollection():
    # Test basic functionality
    fig, ax = plt.subplots()
    x = np.arange(4)
    y = np.arange(3)
    X, Y = np.meshgrid(x, y)
    XY = np.vstack((X.ravel(), Y.ravel())).T

    ww = X / x[-1]
    hh = Y / y[-1]
    aa = np.ones_like(ww) * 20  # first axis is 20 degrees CCW from x axis

    ec = mcollections.EllipseCollection(ww, hh, aa,
                                        units='x',
                                        offsets=XY,
                                        transOffset=ax.transData,
                                        facecolors='none')
    ax.add_collection(ec)
    ax.autoscale_view()


@image_comparison(baseline_images=['polycollection_close'],
                  extensions=['png'], remove_text=True)
def test_polycollection_close():
    from mpl_toolkits.mplot3d import Axes3D

    vertsQuad = [
        [[0., 0.], [0., 1.], [1., 1.], [1., 0.]],
        [[0., 1.], [2., 3.], [2., 2.], [1., 1.]],
        [[2., 2.], [2., 3.], [4., 1.], [3., 1.]],
        [[3., 0.], [3., 1.], [4., 1.], [4., 0.]]]

    fig = plt.figure()
    ax = Axes3D(fig)

    colors = ['r', 'g', 'b', 'y', 'k']
    zpos = list(range(5))

    poly = mcollections.PolyCollection(
        vertsQuad * len(zpos), linewidth=0.25)
    poly.set_alpha(0.7)

    # need to have a z-value for *each* polygon = element!
    zs = []
    cs = []
    for z, c in zip(zpos, colors):
        zs.extend([z] * len(vertsQuad))
        cs.extend([c] * len(vertsQuad))

    poly.set_color(cs)

    ax.add_collection3d(poly, zs=zs, zdir='y')

    # axis limit settings:
    ax.set_xlim3d(0, 4)
    ax.set_zlim3d(0, 3)
    ax.set_ylim3d(0, 4)


@image_comparison(baseline_images=['regularpolycollection_rotate'],
                  extensions=['png'], remove_text=True)
def test_regularpolycollection_rotate():
    xx, yy = np.mgrid[:10, :10]
    xy_points = np.transpose([xx.flatten(), yy.flatten()])
    rotations = np.linspace(0, 2*np.pi, len(xy_points))

    fig, ax = plt.subplots()
    # one single-polygon collection per grid point, each with its own rotation
    for xy, alpha in zip(xy_points, rotations):
        col = mcollections.RegularPolyCollection(
            4, sizes=(100,), rotation=alpha,
            offsets=[xy], transOffset=ax.transData)
        ax.add_collection(col, autolim=True)
    ax.autoscale_view()


@image_comparison(baseline_images=['regularpolycollection_scale'],
                  extensions=['png'], remove_text=True)
def test_regularpolycollection_scale():
    # See issue #3860

    class SquareCollection(mcollections.RegularPolyCollection):
        def __init__(self, **kwargs):
            super(SquareCollection, self).__init__(
                4, rotation=np.pi/4., **kwargs)

        def get_transform(self):
            """Return transform scaling circle areas to data space."""
            ax = self.axes

            pts2pixels = 72.0 / ax.figure.dpi

            scale_x = pts2pixels * ax.bbox.width / ax.viewLim.width
            scale_y = pts2pixels * ax.bbox.height / ax.viewLim.height
            return mtransforms.Affine2D().scale(scale_x, scale_y)

    fig, ax = plt.subplots()

    xy = [(0, 0)]
    # Unit square has a half-diagonal of `1 / sqrt(2)`, so `pi * r**2`
    # equals...
    circle_areas = [np.pi / 2]
    squares = SquareCollection(sizes=circle_areas, offsets=xy,
                               transOffset=ax.transData)
    ax.add_collection(squares, autolim=True)
    ax.axis([-1, 1, -1, 1])


def test_picking():
    fig, ax = plt.subplots()
    col = ax.scatter([0], [0], [1000], picker=True)
    fig.savefig(io.BytesIO(), dpi=fig.dpi)

    # minimal stand-in event: only the x and y attributes are set
    class MouseEvent(object):
        pass
    event = MouseEvent()
    event.x = 325
    event.y = 240

    found, indices = col.contains(event)
    assert found
    assert_array_equal(indices['ind'], [0])


def test_linestyle_single_dashes():
    # smoke test: a single (offset, dash-sequence) linestyle must draw
    plt.scatter([0, 1, 2], [0, 1, 2], linestyle=(0., [2., 2.]))
    plt.draw()


@image_comparison(baseline_images=['size_in_xy'], remove_text=True,
                  extensions=['png'])
def test_size_in_xy():
    fig, ax = plt.subplots()

    widths, heights, angles = (10, 10), 10, 0
    coords = [(10, 10), (15, 15)]
    e = mcollections.EllipseCollection(
        widths, heights, angles,
        units='xy',
        offsets=coords,
        transOffset=ax.transData)

    ax.add_collection(e)

    ax.set_xlim(0, 30)
    ax.set_ylim(0, 30)


def test_pandas_indexing(pd):

    # Should not break when faced with a
    # non-zero indexed series
    index = [11, 12, 13]
    ec = fc = pd.Series(['red', 'blue', 'green'], index=index)
    lw = pd.Series([1, 2, 3], index=index)
    ls = pd.Series(['solid', 'dashed', 'dashdot'], index=index)
    aa = pd.Series([True, False, True], index=index)

    Collection(edgecolors=ec)
    Collection(facecolors=fc)
    Collection(linewidths=lw)
    Collection(linestyles=ls)
    Collection(antialiaseds=aa)


@pytest.mark.style('default')
def test_lslw_bcast():
    col = mcollections.PathCollection([])
    col.set_linestyles(['-', '-'])
    col.set_linewidths([1, 2, 3])

    # 2 linestyles and 3 linewidths: both getters are expected to
    # return 6 entries
    assert_equal(col.get_linestyles(), [(None, None)] * 6)
    assert_equal(col.get_linewidths(), [1, 2, 3] * 2)

    col.set_linestyles(['-', '-', '-'])
    assert_equal(col.get_linestyles(), [(None, None)] * 3)
    assert_equal(col.get_linewidths(), [1, 2, 3])


@pytest.mark.style('default')
def test_capstyle():
    col = mcollections.PathCollection([], capstyle='round')
    assert_equal(col.get_capstyle(), 'round')
    col.set_capstyle('butt')
    assert_equal(col.get_capstyle(), 'butt')


@pytest.mark.style('default')
def test_joinstyle():
    col = mcollections.PathCollection([], joinstyle='round')
    assert_equal(col.get_joinstyle(), 'round')
    col.set_joinstyle('miter')
    assert_equal(col.get_joinstyle(), 'miter')


@image_comparison(baseline_images=['cap_and_joinstyle'],
                  extensions=['png'])
def test_cap_and_joinstyle_image():
    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)
    ax.set_xlim([-0.5, 1.5])
    ax.set_ylim([-0.5, 2.5])

    x = np.array([0.0, 1.0, 0.5])
    ys = np.array([[0.0], [0.5], [1.0]]) + np.array([[0.0, 0.0, 1.0]])

    # three polylines of three points each, sharing the same x values
    segs = np.zeros((3, 3, 2))
    segs[:, :, 0] = x
    segs[:, :, 1] = ys
    line_segments = LineCollection(segs, linewidth=[10, 15, 20])
    line_segments.set_capstyle("round")
    line_segments.set_joinstyle("miter")

    ax.add_collection(line_segments)
    ax.set_title('Line collection with customized caps and joinstyle')


@image_comparison(baseline_images=['scatter_post_alpha'],
                  extensions=['png'], remove_text=True,
                  style='default')
def test_scatter_post_alpha():
    fig, ax = plt.subplots()
    sc = ax.scatter(range(5), range(5), c=range(5))
    # this needs to be here to update internal state
    fig.canvas.draw()
    sc.set_alpha(.1)