1"""
2Tests specific to the collections module.
3"""
4from __future__ import absolute_import, division, print_function
5
6import io
7
8import numpy as np
9from numpy.testing import (
10    assert_array_equal, assert_array_almost_equal, assert_equal)
11import pytest
12
13import matplotlib.pyplot as plt
14import matplotlib.collections as mcollections
15import matplotlib.transforms as mtransforms
16from matplotlib.collections import Collection, LineCollection, EventCollection
17from matplotlib.testing.decorators import image_comparison
18
19
def generate_EventCollection_plot():
    """
    Build the reference EventCollection and draw it on a fresh figure.

    Returns a ``(axes, collection, props)`` tuple, where ``props`` maps each
    construction parameter name to the value that was used, plus an
    ``'extra_positions'`` array for tests that grow the collection.
    """
    props = {'positions': np.array([0., 1., 2., 3., 5., 8., 13., 21.]),
             'extra_positions': np.array([34., 55., 89.]),
             'orientation': 'horizontal',
             'lineoffset': 1,
             'linelength': .5,
             'linewidth': 2,
             'color': [1, 0, 0, 1],
             'linestyle': 'solid',
             'antialiased': True,
             }

    # Every entry except the two position arrays is a constructor keyword.
    ctor_kwargs = {name: value for name, value in props.items()
                   if name not in ('positions', 'extra_positions')}
    coll = EventCollection(props['positions'], **ctor_kwargs)

    splt = plt.figure().add_subplot(1, 1, 1)
    splt.add_collection(coll)
    splt.set_title('EventCollection: default')
    splt.set_xlim(-1, 22)
    splt.set_ylim(0, 2)
    return splt, coll, props
61
62
@image_comparison(baseline_images=['EventCollection_plot__default'])
def test__EventCollection__get_segments():
    """
    Verify that the freshly built collection reports segments consistent
    with the values it was constructed from.
    """
    _, coll, props = generate_EventCollection_plot()
    expected = [props[key] for key in
                ('positions', 'linelength', 'lineoffset', 'orientation')]
    check_segments(coll, *expected)
74
75
def test__EventCollection__get_positions():
    """
    Verify that get_positions returns the positions given at construction.
    """
    _, coll, props = generate_EventCollection_plot()
    expected = props['positions']
    actual = coll.get_positions()
    np.testing.assert_array_equal(expected, actual)
82
83
def test__EventCollection__get_orientation():
    """
    Verify that get_orientation returns the orientation given at
    construction.
    """
    _, coll, props = generate_EventCollection_plot()
    expected = props['orientation']
    assert_equal(expected, coll.get_orientation())
91
92
def test__EventCollection__is_horizontal():
    """
    Verify that is_horizontal reports True for the default collection.
    """
    coll = generate_EventCollection_plot()[1]
    horizontal = coll.is_horizontal()
    assert_equal(True, horizontal)
100
101
def test__EventCollection__get_linelength():
    """
    Verify that get_linelength returns the line length given at
    construction.
    """
    _, coll, props = generate_EventCollection_plot()
    expected = props['linelength']
    assert_equal(expected, coll.get_linelength())
108
109
def test__EventCollection__get_lineoffset():
    """
    Verify that get_lineoffset returns the line offset given at
    construction.
    """
    _, coll, props = generate_EventCollection_plot()
    expected = props['lineoffset']
    assert_equal(expected, coll.get_lineoffset())
116
117
def test__EventCollection__get_linestyle():
    """
    Verify that the default collection reports its line style as a single
    ``(None, None)`` dash entry.
    """
    coll = generate_EventCollection_plot()[1]
    reported = coll.get_linestyle()
    assert_equal(reported, [(None, None)])
124
125
def test__EventCollection__get_color():
    """
    Verify that get_color, and every entry of get_colors, match the color
    given at construction.
    """
    _, coll, props = generate_EventCollection_plot()
    expected = props['color']
    np.testing.assert_array_equal(expected, coll.get_color())
    check_allprop_array(coll.get_colors(), expected)
133
134
@image_comparison(baseline_images=['EventCollection_plot__set_positions'])
def test__EventCollection__set_positions():
    """
    Verify that set_positions replaces the positions and that the segments
    follow the new positions.
    """
    splt, coll, props = generate_EventCollection_plot()
    replacement = np.hstack((props['positions'], props['extra_positions']))
    coll.set_positions(replacement)

    np.testing.assert_array_equal(replacement, coll.get_positions())
    check_segments(coll, replacement, props['linelength'],
                   props['lineoffset'], props['orientation'])

    splt.set_title('EventCollection: set_positions')
    splt.set_xlim(-1, 90)
150
151
@image_comparison(baseline_images=['EventCollection_plot__add_positions'])
def test__EventCollection__add_positions():
    """
    Verify that add_positions adds a single new position to the collection.
    """
    splt, coll, props = generate_EventCollection_plot()
    added = props['extra_positions'][0]
    expected = np.hstack((props['positions'], added))
    coll.add_positions(added)

    np.testing.assert_array_equal(expected, coll.get_positions())
    check_segments(coll, expected, props['linelength'],
                   props['lineoffset'], props['orientation'])

    splt.set_title('EventCollection: add_positions')
    splt.set_xlim(-1, 35)
169
170
@image_comparison(baseline_images=['EventCollection_plot__append_positions'])
def test__EventCollection__append_positions():
    """
    Verify that append_positions adds a single new position to the
    collection.
    """
    splt, coll, props = generate_EventCollection_plot()
    appended = props['extra_positions'][2]
    expected = np.hstack((props['positions'], appended))
    coll.append_positions(appended)

    np.testing.assert_array_equal(expected, coll.get_positions())
    check_segments(coll, expected, props['linelength'],
                   props['lineoffset'], props['orientation'])

    splt.set_title('EventCollection: append_positions')
    splt.set_xlim(-1, 90)
188
189
@image_comparison(baseline_images=['EventCollection_plot__extend_positions'])
def test__EventCollection__extend_positions():
    """
    Verify that extend_positions adds several new positions to the
    collection.
    """
    splt, coll, props = generate_EventCollection_plot()
    extension = props['extra_positions'][1:]
    expected = np.hstack((props['positions'], extension))
    coll.extend_positions(extension)

    np.testing.assert_array_equal(expected, coll.get_positions())
    check_segments(coll, expected, props['linelength'],
                   props['lineoffset'], props['orientation'])

    splt.set_title('EventCollection: extend_positions')
    splt.set_xlim(-1, 90)
207
208
@image_comparison(baseline_images=['EventCollection_plot__switch_orientation'])
def test__EventCollection__switch_orientation():
    """
    Verify that switch_orientation turns the horizontal collection into a
    vertical one.
    """
    splt, coll, props = generate_EventCollection_plot()
    coll.switch_orientation()

    flipped = 'vertical'
    assert_equal(flipped, coll.get_orientation())
    assert_equal(False, coll.is_horizontal())

    # Segments are checked against whatever positions the collection now
    # reports.
    current_positions = coll.get_positions()
    check_segments(coll, current_positions, props['linelength'],
                   props['lineoffset'], flipped)

    splt.set_title('EventCollection: switch_orientation')
    splt.set_ylim(-1, 22)
    splt.set_xlim(0, 2)
227
228
@image_comparison(
    baseline_images=['EventCollection_plot__switch_orientation__2x'])
def test__EventCollection__switch_orientation_2x():
    """
    Verify that two consecutive switch_orientation calls restore the
    original orientation and positions.
    """
    splt, coll, props = generate_EventCollection_plot()
    for _ in range(2):
        coll.switch_orientation()
    current_positions = coll.get_positions()

    original_orientation = props['orientation']
    assert_equal(original_orientation, coll.get_orientation())
    assert_equal(True, coll.is_horizontal())
    np.testing.assert_array_equal(props['positions'], current_positions)
    check_segments(coll, current_positions, props['linelength'],
                   props['lineoffset'], original_orientation)

    splt.set_title('EventCollection: switch_orientation 2x')
249
250
@image_comparison(baseline_images=['EventCollection_plot__set_orientation'])
def test__EventCollection__set_orientation():
    """
    Verify that set_orientation('vertical') makes the collection vertical
    and that the segments follow.
    """
    splt, coll, props = generate_EventCollection_plot()
    target = 'vertical'
    coll.set_orientation(target)

    assert_equal(target, coll.get_orientation())
    assert_equal(False, coll.is_horizontal())
    check_segments(coll, props['positions'], props['linelength'],
                   props['lineoffset'], target)

    splt.set_title('EventCollection: set_orientation')
    splt.set_ylim(-1, 22)
    splt.set_xlim(0, 2)
269
270
@image_comparison(baseline_images=['EventCollection_plot__set_linelength'])
def test__EventCollection__set_linelength():
    """
    Verify that set_linelength changes the reported line length and the
    segment extents.
    """
    splt, coll, props = generate_EventCollection_plot()
    longer = 15
    coll.set_linelength(longer)

    assert_equal(longer, coll.get_linelength())
    check_segments(coll, props['positions'], longer,
                   props['lineoffset'], props['orientation'])

    splt.set_title('EventCollection: set_linelength')
    splt.set_ylim(-20, 20)
287
288
@image_comparison(baseline_images=['EventCollection_plot__set_lineoffset'])
def test__EventCollection__set_lineoffset():
    """
    Verify that set_lineoffset changes the reported line offset and the
    segment extents.
    """
    splt, coll, props = generate_EventCollection_plot()
    shifted = -5.
    coll.set_lineoffset(shifted)

    assert_equal(shifted, coll.get_lineoffset())
    check_segments(coll, props['positions'], props['linelength'],
                   shifted, props['orientation'])

    splt.set_title('EventCollection: set_lineoffset')
    splt.set_ylim(-6, -4)
305
306
@image_comparison(baseline_images=['EventCollection_plot__set_linestyle'])
def test__EventCollection__set_linestyle():
    """
    Verify that set_linestyle('dashed') is reported back as the expected
    dash pattern.
    """
    splt, coll, _ = generate_EventCollection_plot()
    coll.set_linestyle('dashed')
    expected_dashes = [(0, (6.0, 6.0))]
    assert_equal(coll.get_linestyle(), expected_dashes)
    splt.set_title('EventCollection: set_linestyle')
317
318
@image_comparison(baseline_images=['EventCollection_plot__set_ls_dash'],
                  remove_text=True)
def test__EventCollection__set_linestyle_single_dash():
    """
    Verify that set_linestyle accepts one explicit (offset, on/off) dash
    pattern and reports it back.
    """
    splt, coll, _ = generate_EventCollection_plot()
    dash_pattern = (0, (6., 6.))
    coll.set_linestyle(dash_pattern)
    expected_dashes = [(0, (6.0, 6.0))]
    assert_equal(coll.get_linestyle(), expected_dashes)
    splt.set_title('EventCollection: set_linestyle')
330
331
@image_comparison(baseline_images=['EventCollection_plot__set_linewidth'])
def test__EventCollection__set_linewidth():
    """
    Verify that set_linewidth changes the line width reported by
    get_linewidth.
    """
    splt, coll, _ = generate_EventCollection_plot()
    thicker = 5
    coll.set_linewidth(thicker)
    assert_equal(coll.get_linewidth(), thicker)
    splt.set_title('EventCollection: set_linewidth')
342
343
@image_comparison(baseline_images=['EventCollection_plot__set_color'])
def test__EventCollection__set_color():
    """
    Verify that set_color changes get_color and every entry of get_colors.
    """
    splt, coll, _ = generate_EventCollection_plot()
    cyan = np.array([0, 1, 1, 1])
    coll.set_color(cyan)

    np.testing.assert_array_equal(cyan, coll.get_color())
    check_allprop_array(coll.get_colors(), cyan)
    splt.set_title('EventCollection: set_color')
355
356
def check_segments(coll, positions, linelength, lineoffset, orientation):
    """
    Check that every segment of *coll* has the expected coordinates.

    note: this is not a test, it is used by tests

    Parameters
    ----------
    coll : object providing ``get_segments()``
        Each returned segment is indexed as ``segment[point, axis]`` with
        two points and two axes.
    positions : sequence
        Expected event position of each segment, in segment order.
    linelength : number
        Expected full length of each segment.
    lineoffset : number
        Expected centre of each segment along its length.
    orientation : str or None
        'horizontal', 'none' or None for a horizontal collection,
        'vertical' for a vertical one (strings are case-insensitive).

    Raises
    ------
    ValueError
        If *orientation* is none of the accepted values.
    AssertionError
        If the number of segments or any coordinate differs from what is
        expected.
    """
    segments = coll.get_segments()

    # None has to be handled before calling .lower(); the original
    # condition called orientation.lower() first and so raised
    # AttributeError for None.
    orient = 'horizontal' if orientation is None else orientation.lower()
    if orient in ('horizontal', 'none'):
        # if horizontal, the extent is on the y-axis and the position on
        # the x-axis
        pos1, pos2 = 1, 0
    elif orient == 'vertical':
        # if vertical, the extent is on the x-axis and the position on
        # the y-axis
        pos1, pos2 = 0, 1
    else:
        raise ValueError("orientation must be 'horizontal' or 'vertical'")

    # Without this, a collection with too few (or no) segments would pass
    # the per-segment loop below without checking every position.
    assert_equal(len(segments), len(positions))

    half_length = linelength / 2.
    first_end = lineoffset + half_length
    second_end = lineoffset - half_length

    # test to make sure each segment is correct
    for i, segment in enumerate(segments):
        assert_equal(segment[0, pos1], first_end)
        assert_equal(segment[1, pos1], second_end)
        assert_equal(segment[0, pos2], positions[i])
        assert_equal(segment[1, pos2], positions[i])
383
384
def check_allprop_array(values, target):
    """
    Assert that each array-like entry of *values* equals *target*.

    note: this is not a test, it is used by tests
    """
    compare = np.testing.assert_array_equal
    for entry in values:
        compare(entry, target)
393
394
def test_null_collection_datalim():
    """An empty PathCollection must report the null bounding box."""
    empty = mcollections.PathCollection([])
    datalim = empty.get_datalim(mtransforms.IdentityTransform())
    null_box = mtransforms.Bbox.null()
    assert_array_equal(datalim.get_points(), null_box.get_points())
400
401
def test_add_collection():
    """
    Data limits must be unchanged by adding an empty collection.

    Github issue #1490, pull #1497.
    """
    plt.figure()
    ax = plt.axes()
    ax.add_collection(ax.scatter([0, 1], [0, 1]))
    bounds_before = ax.dataLim.bounds
    # The empty scatter is the operation under test; its return value is
    # not needed.
    ax.scatter([], [])
    assert_equal(ax.dataLim.bounds, bounds_before)
412
413
def test_quiver_limits():
    """
    Check the data limits of a quiver plot, first as reported by the quiver
    itself and then on the axes when a translated transform is used.
    """
    ax = plt.axes()
    x = np.arange(8)
    y = np.arange(10)
    u = v = np.linspace(0, 10, 80).reshape(10, 8)
    q = plt.quiver(x, y, u, v)
    quiver_bounds = q.get_datalim(ax.transData).bounds
    assert_equal(quiver_bounds, (0., 0., 7., 9.))

    plt.figure()
    ax = plt.axes()
    y, x = np.meshgrid(np.linspace(-2, 4, 10), np.linspace(-5, 10, 20))
    shift = mtransforms.Affine2D().translate(25, 32)
    trans = shift + ax.transData
    plt.quiver(x, y, np.sin(x), np.cos(y), transform=trans)
    assert_equal(ax.dataLim.bounds, (20.0, 30.0, 15.0, 6.0))
429
430
def test_barb_limits():
    """
    Check the axes data limits after drawing barbs with a translated
    transform.
    """
    ax = plt.axes()
    y, x = np.meshgrid(np.linspace(-2, 4, 10), np.linspace(-5, 10, 20))
    shift = mtransforms.Affine2D().translate(25, 32)
    trans = shift + ax.transData
    plt.barbs(x, y, np.sin(x), np.cos(y), transform=trans)
    # The calculated bounds are only approximately the bounds of the
    # original data, because the entire path is taken into account when
    # updating the datalim; hence the one-decimal tolerance.
    expected_bounds = (20, 30, 15, 6)
    assert_array_almost_equal(ax.dataLim.bounds, expected_bounds,
                              decimal=1)
443
444
@image_comparison(baseline_images=['EllipseCollection_test_image'],
                  extensions=['png'],
                  remove_text=True)
def test_EllipseCollection():
    """Basic EllipseCollection rendering on a 4x3 grid of offsets."""
    ax = plt.subplots()[1]
    x = np.arange(4)
    y = np.arange(3)
    X, Y = np.meshgrid(x, y)
    # one (x, y) offset row per grid point
    XY = np.vstack((X.ravel(), Y.ravel())).T

    # widths and heights are the grid coordinates scaled by the last
    # value of each axis
    ww = X / x[-1]
    hh = Y / y[-1]
    aa = 20 * np.ones_like(ww)  # first axis is 20 degrees CCW from x axis

    ellipse_kwargs = dict(units='x',
                          offsets=XY,
                          transOffset=ax.transData,
                          facecolors='none')
    ec = mcollections.EllipseCollection(ww, hh, aa, **ellipse_kwargs)
    ax.add_collection(ec)
    ax.autoscale_view()
467
468
@image_comparison(baseline_images=['polycollection_close'],
                  extensions=['png'], remove_text=True)
def test_polycollection_close():
    """
    Draw the same four quadrilaterals at five different y-depths on 3D axes,
    one color per depth.
    """
    from mpl_toolkits.mplot3d import Axes3D

    vertsQuad = [
        [[0., 0.], [0., 1.], [1., 1.], [1., 0.]],
        [[0., 1.], [2., 3.], [2., 2.], [1., 1.]],
        [[2., 2.], [2., 3.], [4., 1.], [3., 1.]],
        [[3., 0.], [3., 1.], [4., 1.], [4., 0.]]]
    n_quads = len(vertsQuad)

    ax = Axes3D(plt.figure())

    colors = ['r', 'g', 'b', 'y', 'k']
    zpos = list(range(5))

    poly = mcollections.PolyCollection(
        vertsQuad * len(zpos), linewidth=0.25)
    poly.set_alpha(0.7)

    # need to have a z-value (and a color) for *each* polygon = element!
    zs = []
    cs = []
    for z, c in zip(zpos, colors):
        zs += [z] * n_quads
        cs += [c] * n_quads

    poly.set_color(cs)

    ax.add_collection3d(poly, zs=zs, zdir='y')

    # axis limit settings:
    ax.set_xlim3d(0, 4)
    ax.set_zlim3d(0, 3)
    ax.set_ylim3d(0, 4)
505
506
@image_comparison(baseline_images=['regularpolycollection_rotate'],
                  extensions=['png'], remove_text=True)
def test_regularpolycollection_rotate():
    """
    Draw one four-sided RegularPolyCollection per point of a 10x10 grid,
    each with its own rotation between 0 and 2*pi.
    """
    xx, yy = np.mgrid[:10, :10]
    xy_points = np.transpose([xx.flatten(), yy.flatten()])
    rotations = np.linspace(0, 2*np.pi, len(xy_points))

    ax = plt.subplots()[1]
    for index, xy in enumerate(xy_points):
        col = mcollections.RegularPolyCollection(
            4, sizes=(100,), rotation=rotations[index],
            offsets=[xy], transOffset=ax.transData)
        ax.add_collection(col, autolim=True)
    ax.autoscale_view()
521
522
@image_comparison(baseline_images=['regularpolycollection_scale'],
                  extensions=['png'], remove_text=True)
def test_regularpolycollection_scale():
    """Render a square collection with a custom transform (issue #3860)."""

    class SquareCollection(mcollections.RegularPolyCollection):
        def __init__(self, **kwargs):
            # four sides, rotated by pi/4
            super(SquareCollection, self).__init__(
                4, rotation=np.pi/4., **kwargs)

        def get_transform(self):
            """Return the transform that scales circle areas to data space."""
            ax = self.axes
            pts2pixels = 72.0 / ax.figure.dpi
            bbox = ax.bbox
            view = ax.viewLim
            return mtransforms.Affine2D().scale(
                pts2pixels * bbox.width / view.width,
                pts2pixels * bbox.height / view.height)

    ax = plt.subplots()[1]

    # Unit square has a half-diagonal of `1 / sqrt(2)`, so `pi * r**2`
    # equals pi / 2.
    circle_areas = [np.pi / 2]
    squares = SquareCollection(sizes=circle_areas, offsets=[(0, 0)],
                               transOffset=ax.transData)
    ax.add_collection(squares, autolim=True)
    ax.axis([-1, 1, -1, 1])
553
554
def test_picking():
    """
    A pickable scatter point must report a hit for an event located at
    pixel (325, 240) after the figure has been rendered.
    """
    fig, ax = plt.subplots()
    col = ax.scatter([0], [0], [1000], picker=True)
    fig.savefig(io.BytesIO(), dpi=fig.dpi)

    class MouseEvent(object):
        # minimal stand-in carrying only x and y
        def __init__(self, x, y):
            self.x = x
            self.y = y

    found, indices = col.contains(MouseEvent(325, 240))
    assert found
    assert_array_equal(indices['ind'], [0])
569
570
def test_linestyle_single_dashes():
    """Scatter must accept a single (offset, on/off) dash pattern and draw."""
    dash_pattern = (0., [2., 2.])
    plt.scatter([0, 1, 2], [0, 1, 2], linestyle=dash_pattern)
    plt.draw()
574
575
@image_comparison(baseline_images=['size_in_xy'], remove_text=True,
                  extensions=['png'])
def test_size_in_xy():
    """Render two ellipses whose sizes are given in 'xy' units."""
    ax = plt.subplots()[1]

    widths = (10, 10)
    heights = 10
    angles = 0
    coords = [(10, 10), (15, 15)]
    e = mcollections.EllipseCollection(
        widths, heights, angles,
        units='xy',
        offsets=coords,
        transOffset=ax.transData)

    ax.add_collection(e)

    ax.set_xlim(0, 30)
    ax.set_ylim(0, 30)
594
595
def test_pandas_indexing(pd):
    """
    Constructing a Collection must not fail when style arguments are given
    as pandas Series whose index does not start at zero.
    """
    index = [11, 12, 13]
    colors = pd.Series(['red', 'blue', 'green'], index=index)
    lw = pd.Series([1, 2, 3], index=index)
    ls = pd.Series(['solid', 'dashed', 'dashdot'], index=index)
    aa = pd.Series([True, False, True], index=index)

    # one construction per keyword, same order as the series above
    cases = [('edgecolors', colors),
             ('facecolors', colors),
             ('linewidths', lw),
             ('linestyles', ls),
             ('antialiaseds', aa)]
    for keyword, series in cases:
        Collection(**{keyword: series})
611
612
@pytest.mark.style('default')
def test_lslw_bcast():
    """
    Check the lengths of the line style and line width lists reported back
    when the two are set with different and then equal lengths.
    """
    solid = (None, None)
    widths = [1, 2, 3]

    col = mcollections.PathCollection([])
    col.set_linestyles(['-'] * 2)
    col.set_linewidths(widths)

    # 2 styles and 3 widths are expected to be reported as 6 of each
    assert_equal(col.get_linestyles(), [solid] * 6)
    assert_equal(col.get_linewidths(), [1, 2, 3] * 2)

    # with matching lengths, 3 of each are expected
    col.set_linestyles(['-'] * 3)
    assert_equal(col.get_linestyles(), [solid] * 3)
    assert_equal(col.get_linewidths(), [1, 2, 3])
625
626
@pytest.mark.style('default')
def test_capstyle():
    """The cap style set at construction or via set_capstyle is reported."""
    col = mcollections.PathCollection([], capstyle='round')
    initial = col.get_capstyle()
    assert_equal(initial, 'round')

    col.set_capstyle('butt')
    updated = col.get_capstyle()
    assert_equal(updated, 'butt')
633
634
@pytest.mark.style('default')
def test_joinstyle():
    """The join style set at construction or via set_joinstyle is reported."""
    col = mcollections.PathCollection([], joinstyle='round')
    initial = col.get_joinstyle()
    assert_equal(initial, 'round')

    col.set_joinstyle('miter')
    updated = col.get_joinstyle()
    assert_equal(updated, 'miter')
641
642
@image_comparison(baseline_images=['cap_and_joinstyle'],
                  extensions=['png'])
def test_cap_and_joinstyle_image():
    """
    Render three thick three-point lines with round caps and miter joins.
    """
    ax = plt.figure().add_subplot(1, 1, 1)
    ax.set_xlim([-0.5, 1.5])
    ax.set_ylim([-0.5, 2.5])

    x = np.array([0.0, 1.0, 0.5])
    # same y shape for each line, shifted up by 0, 0.5 and 1
    ys = np.array([[0.0], [0.5], [1.0]]) + np.array([[0.0, 0.0, 1.0]])

    # segs[line, point] = (x, y)
    segs = np.zeros((3, 3, 2))
    segs[..., 0] = x
    segs[..., 1] = ys

    line_segments = LineCollection(segs, linewidth=[10, 15, 20])
    line_segments.set_capstyle("round")
    line_segments.set_joinstyle("miter")

    ax.add_collection(line_segments)
    ax.set_title('Line collection with customized caps and joinstyle')
663
664
@image_comparison(baseline_images=['scatter_post_alpha'],
                  extensions=['png'], remove_text=True,
                  style='default')
def test_scatter_post_alpha():
    """Set alpha on a color-mapped scatter after it has been drawn once."""
    fig, ax = plt.subplots()
    values = range(5)
    sc = ax.scatter(values, values, c=values)
    # this draw needs to be here to update internal state
    fig.canvas.draw()
    sc.set_alpha(.1)
674