1from __future__ import absolute_import, division, print_function
2
3import six
4
5from copy import copy
6import io
7import os
8import sys
9import warnings
10
11import numpy as np
12from numpy import ma
13from numpy.testing import assert_array_equal
14
15from matplotlib import (
16    colors, image as mimage, patches, pyplot as plt,
17    rc_context, rcParams)
18from matplotlib.image import (AxesImage, BboxImage, FigureImage,
19                              NonUniformImage, PcolorImage)
20from matplotlib.testing.decorators import image_comparison
21from matplotlib.transforms import Bbox, Affine2D, TransformedBbox
22
23import pytest
24
25
@image_comparison(baseline_images=['image_interps'], style='mpl20')
def test_image_interps():
    """Render one ramp with nearest, bilinear and bicubic interpolation."""
    ramp = np.arange(100).reshape(5, 20)

    fig = plt.figure()
    for position, method in zip((311, 312, 313),
                                ('nearest', 'bilinear', 'bicubic')):
        ax = fig.add_subplot(position)
        ax.imshow(ramp, interpolation=method)
        if position == 311:
            # Only the top panel carries the overall title.
            ax.set_title('three interpolations')
        ax.set_ylabel(method)
45
46
@image_comparison(baseline_images=['interp_nearest_vs_none'],
                  extensions=['pdf', 'svg'], remove_text=True)
def test_interp_nearest_vs_none():
    """Compare "none" and "nearest" interpolation side by side."""
    # A tiny savefig dpi makes the difference between the two modes very
    # visible.  In pdf/svg the dpi only affects the embedded images, so the
    # rest of the figure is unaffected; Agg output would become unusably
    # small at this dpi.
    rcParams['savefig.dpi'] = 3
    pixels = np.array([[[218, 165, 32], [122, 103, 238]],
                       [[127, 255, 0], [255, 99, 71]]], dtype=np.uint8)
    fig = plt.figure()
    panels = ((121, 'none', 'interpolation none'),
              (122, 'nearest', 'interpolation nearest'))
    for position, method, title in panels:
        ax = fig.add_subplot(position)
        ax.imshow(pixels, interpolation=method)
        ax.set_title(title)
65
66
def do_figimage(suppressComposite):
    """Place four mirrored copies of a test pattern with ``figimage``.

    Shared by ``test_figimage0`` and ``test_figimage1``.  *suppressComposite*
    is stored on the new figure before any image is added.  The copies sit
    at pixel offsets (0|100, 0|100); the top row is flipped vertically and
    the right column horizontally.
    """
    fig = plt.figure(figsize=(2, 2), dpi=100)
    fig.suppressComposite = suppressComposite

    coords = np.arange(100.0) / 100.0
    x, y = np.ix_(coords, coords)
    smooth = np.sin(x**2 + y**2 - x*y)
    ripples = np.sin(20*x**2 + 50*y**2)
    pattern = smooth + ripples/5

    keep = slice(None)
    flip = slice(None, None, -1)
    for xo, yo in ((0, 0), (0, 100), (100, 0), (100, 100)):
        rows = flip if yo else keep
        cols = flip if xo else keep
        fig.figimage(pattern[rows, cols], xo=xo, yo=yo, origin='lower')
80
81
@image_comparison(baseline_images=['figimage-0'],
                  extensions=['png', 'pdf'])
def test_figimage0():
    """figimage output with image compositing left enabled."""
    do_figimage(suppressComposite=False)
89
90
@image_comparison(baseline_images=['figimage-1'],
                  extensions=['png', 'pdf'])
def test_figimage1():
    """figimage output with image compositing suppressed."""
    do_figimage(suppressComposite=True)
97
98
def test_image_python_io():
    """A figure saved to an in-memory buffer can be read back by imread."""
    fig, ax = plt.subplots()
    ax.plot([1, 2, 3])
    with io.BytesIO() as stream:
        fig.savefig(stream)
        stream.seek(0)
        plt.imread(stream)
106
107
def test_imread_pil_uint16():
    """A 16-bit TIFF read through Pillow keeps its uint16 dtype and values."""
    pytest.importorskip("PIL")
    tif_path = os.path.join(os.path.dirname(__file__), 'baseline_images',
                            'test_image', 'uint16.tif')
    pixels = plt.imread(tif_path)
    assert pixels.dtype == np.uint16
    assert np.sum(pixels) == 134184960
114
115
@pytest.mark.skipif(sys.version_info < (3, 6), reason="requires Python 3.6+")
def test_imread_fspath():
    """imread accepts a pathlib.Path pointing at the file to read."""
    pytest.importorskip("PIL")
    from pathlib import Path
    tif = Path(__file__).parent / 'baseline_images/test_image/uint16.tif'
    pixels = plt.imread(tif)
    assert pixels.dtype == np.uint16
    assert np.sum(pixels) == 134184960
124
125
def test_imsave():
    """imsave's *dpi* is metadata only and never changes the pixel data.

    The same random array is saved once with dpi=1 and once with dpi=100.
    Both PNGs are read back and must have the same shape and content.
    """
    np.random.seed(1)
    data = np.random.rand(256, 128)

    readback = []
    for dpi in (1, 100):
        buff = io.BytesIO()
        plt.imsave(buff, data, dpi=dpi)
        buff.seek(0)
        readback.append(plt.imread(buff))

    for arr in readback:
        assert arr.shape == (256, 128, 4)

    arr_dpi1, arr_dpi100 = readback
    assert_array_equal(arr_dpi1, arr_dpi100)
153
154
@pytest.mark.skipif(sys.version_info < (3, 6), reason="requires Python 3.6+")
@pytest.mark.parametrize("fmt", ["png", "pdf", "ps", "eps", "svg"])
def test_imsave_fspath(fmt):
    """imsave accepts a pathlib.Path target for each listed format."""
    pathlib = pytest.importorskip("pathlib")
    target = pathlib.Path(os.devnull)
    plt.imsave(target, np.array([[0, 1]]), format=fmt)
160
161
def test_imsave_color_alpha():
    """RGBA float data survives an imsave/imread PNG round trip.

    imsave must accept (M, N, 4) input without raising.  The comparison is
    made at 8-bit precision because that is what the PNG file stores.  With
    origin='lower' the rows come back flipped.
    """
    np.random.seed(1)

    for origin in ['lower', 'upper']:
        data = np.random.rand(16, 16, 4)
        buff = io.BytesIO()
        plt.imsave(buff, data, origin=origin, format="png")
        buff.seek(0)
        roundtrip = (255 * plt.imread(buff)).astype('uint8')

        expected = (255 * data).astype('uint8')
        if origin == 'lower':
            expected = expected[::-1]

        assert_array_equal(expected, roundtrip)
185
186
@image_comparison(baseline_images=['image_alpha'], remove_text=True)
def test_image_alpha():
    """Show the same random data at full and half alpha, with two interps."""
    plt.figure()

    np.random.seed(0)
    Z = np.random.rand(6, 6)

    panels = [(131, 1.0, 'none'), (132, 0.5, 'none'), (133, 0.5, 'nearest')]
    for position, alpha, interp in panels:
        plt.subplot(position)
        plt.imshow(Z, alpha=alpha, interpolation=interp)
202
def test_cursor_data():
    """get_cursor_data returns the pixel under the mouse, or None outside.

    The test covers both image origins and an image with an explicit
    extent.  The out-of-image probes are regression checks for issue #4957.
    """
    from matplotlib.backend_bases import MouseEvent

    def cursor_value(image, axes, figure, xy):
        # Build a motion event at data coordinates *xy* and query the image.
        xdisp, ydisp = axes.transData.transform_point(xy)
        event = MouseEvent('motion_notify_event', figure.canvas, xdisp, ydisp)
        return image.get_cursor_data(event)

    data = np.arange(100).reshape(10, 10)

    fig, ax = plt.subplots()
    im = ax.imshow(data, origin='upper')
    assert cursor_value(im, ax, fig, [4, 4]) == 44
    # Just past the right edge of the image (issue #4957).
    assert cursor_value(im, ax, fig, [10.1, 4]) is None

    # NOTE(review): for this origin='upper' image, a probe at (0.1, -0.1)
    # was observed to return 0 instead of None, so that case is not
    # asserted.  The analogous probe at (0.01, -0.01) on the explicit-extent
    # image below does return None.

    ax.clear()
    im = ax.imshow(data, origin='lower')
    assert cursor_value(im, ax, fig, [4, 4]) == 44

    fig, ax = plt.subplots()
    im = ax.imshow(data, extent=[0, 0.5, 0, 0.5])
    assert cursor_value(im, ax, fig, [0.25, 0.25]) == 55
    # Points outside the extent, to the right and below (issue #4957).
    assert cursor_value(im, ax, fig, [0.75, 0.25]) is None
    assert cursor_value(im, ax, fig, [0.01, -0.01]) is None
263
264
@image_comparison(baseline_images=['image_clip'], style='mpl20')
def test_image_clip():
    """Clip an image to a unit circle given in data coordinates."""
    fig, ax = plt.subplots()
    im = ax.imshow([[1, 2], [3, 4]])
    clip_circle = patches.Circle((0, 0), radius=1, transform=ax.transData)
    im.set_clip_path(clip_circle)
273
274
@image_comparison(baseline_images=['image_cliprect'], style='mpl20')
def test_image_cliprect():
    """Clip an image to a rectangle given in data coordinates."""
    fig, ax = plt.subplots()
    im = ax.imshow([[1, 2], [3, 4]], extent=(0, 5, 0, 5))

    clip_rect = patches.Rectangle(xy=(1, 1), width=2, height=2,
                                  transform=im.axes.transData)
    im.set_clip_path(clip_rect)
286
287
@image_comparison(baseline_images=['imshow'], remove_text=True, style='mpl20')
def test_imshow():
    """A bilinear image placed by *extent* inside wider axes limits."""
    fig, ax = plt.subplots()
    grid = np.arange(100).reshape((10, 10))
    ax.imshow(grid, interpolation="bilinear", extent=(1, 2, 1, 2))
    ax.set_xlim(0, 3)
    ax.set_ylim(0, 3)
295
296
@image_comparison(baseline_images=['no_interpolation_origin'],
                  remove_text=True)
def test_no_interpolation_origin():
    """interpolation='none' with origin='lower' (top) and default (bottom)."""
    strip = np.arange(100).reshape((2, 50))
    fig = plt.figure()
    fig.add_subplot(211).imshow(strip, origin="lower", interpolation='none')
    fig.add_subplot(212).imshow(strip, interpolation='none')
307
308
@image_comparison(baseline_images=['image_shift'], remove_text=True,
                  extensions=['pdf', 'svg'])
def test_image_shift():
    """A log-normed image over a very narrow x extent with a large offset."""
    img_data = [[1.0 / x + 1.0 / y for x in range(1, 100)]
                for y in range(1, 100)]
    t_min, t_max = 734717.945208, 734717.946366

    fig, ax = plt.subplots()
    ax.imshow(img_data, norm=colors.LogNorm(), interpolation='none',
              extent=(t_min, t_max, 1, 100))
    ax.set_aspect('auto')
322
323
def test_image_edges():
    """An image that overfills the axes must reach both edges of the output.

    The axes cover the whole figure, the image extent is far larger than
    the visible limits, and the figure is saved on a pure-green facecolor.
    An edge column that is green in every row means the background shows
    through there.  Both the left and the right edge are checked.
    """
    f = plt.figure(figsize=[1, 1])
    ax = f.add_axes([0, 0, 1, 1], frameon=False)

    data = np.tile(np.arange(12), 15).reshape(20, 9)

    ax.imshow(data, origin='upper', extent=[-10, 10, -10, 10],
              interpolation='none', cmap='gray')

    x = y = 2
    ax.set_xlim([-x, x])
    ax.set_ylim([-y, y])

    ax.set_xticks([])
    ax.set_yticks([])

    buf = io.BytesIO()
    f.savefig(buf, facecolor=(0, 1, 0))

    buf.seek(0)

    saved = plt.imread(buf)
    n_rows = saved.shape[0]
    # Check both edges.  The left-edge sums used to be computed and then
    # immediately overwritten, so only the right edge was ever tested.
    for column in (0, -1):
        # Per-channel sums down the column.  A pure-green column has a
        # green total equal to the number of rows.
        green_total = saved[:, column].sum(axis=0)[1]
        assert green_total != n_rows, \
            'Expected a non-green edge - but sadly, it was.'
350
351
@image_comparison(baseline_images=['image_composite_background'],
                  remove_text=True,
                  style='mpl20')
def test_image_composite_background():
    """Two images composited over a semi-transparent red axes background."""
    fig, ax = plt.subplots()
    tile = np.arange(12).reshape(4, 3)
    for left, right in ((0, 2), (4, 6)):
        ax.imshow(tile, extent=[left, right, 15, 0])
    ax.set_facecolor((1, 0, 0, 0.5))
    ax.set_xlim([0, 12])
362
363
@image_comparison(baseline_images=['image_composite_alpha'],
                  remove_text=True)
def test_image_composite_alpha():
    """
    Tests that the alpha value is recognized and correctly applied in the
    process of compositing images together.
    """
    fig, ax = plt.subplots()
    # 21-element alpha ramp: 0 -> 1 in steps of 0.1, then back down to 0.
    ramp = np.concatenate((np.arange(0, 1.1, 0.1),
                           np.arange(0, 1, 0.1)[::-1]))

    # Red image with the ramp running along its 21 columns.
    red = np.zeros((11, 21, 4))
    red[:, :, 0] = 1
    red[:, :, 3] = ramp
    # Yellow image with the ramp running down its 21 rows.
    yellow = np.zeros((21, 11, 4))
    yellow[:, :, 0] = 1
    yellow[:, :, 1] = 1
    yellow[:, :, 3] = ramp[:, np.newaxis]

    # alpha=None is imshow's default, i.e. no extra global alpha.
    for left, alpha in ((1, 0.3), (2, 0.6), (3, None)):
        ax.imshow(red, extent=[left, left + 1, 5, 0], alpha=alpha)
    for bottom, alpha in ((1, None), (2, 0.6), (3, 0.3)):
        ax.imshow(yellow, extent=[0, 5, bottom, bottom + 1], alpha=alpha)

    ax.set_facecolor((0, 0.5, 0, 1))
    ax.set_xlim([0, 5])
    ax.set_ylim([5, 0])
388
389
@image_comparison(baseline_images=['rasterize_10dpi'],
                  extensions=['pdf', 'svg'],
                  remove_text=True, style='mpl20')
def test_rasterize_dpi():
    # Check where rasterized content lands when the output dpi is not the
    # default.  The panels hold an imshow image, a rasterized line and the
    # same line drawn as vectors, so a misplaced raster stands out.  A very
    # low dpi (savefig.dpi, set at the end) is used instead of a high one,
    # which makes the non-standard resolution obvious in the comparison.
    img = np.asarray([[1, 2], [3, 4]])

    fig, axes = plt.subplots(1, 3, figsize=(3, 1))

    axes[0].imshow(img)

    for ax, extra in ((axes[1], {'rasterized': True}), (axes[2], {})):
        ax.plot([0, 1], [0, 1], linewidth=20., **extra)
        ax.set(xlim=(0, 1), ylim=(-1, 2))

    # Low-dpi PDF rasterization artifacts break the comparison, so hide
    # fine structures such as ticks and spines.
    for ax in axes:
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)

    rcParams['savefig.dpi'] = 10
420
421
@image_comparison(baseline_images=['bbox_image_inverted'], remove_text=True,
                  style='mpl20')
def test_bbox_image_inverted():
    """Draw BboxImages, one given an inverted (max corner first) data bbox."""
    fig, ax = plt.subplots()

    # Corners given as (max, min) in data coordinates.
    ramp_im = BboxImage(
        TransformedBbox(Bbox([[100, 100], [0, 0]]), ax.transData))
    ramp_im.set_data(np.arange(100).reshape((10, 10)))
    ramp_im.set_clip_on(False)
    ax.set_xlim(0, 100)
    ax.set_ylim(0, 100)
    ax.add_artist(ramp_im)

    # A small identity-matrix image placed in figure coordinates.
    eye_im = BboxImage(
        TransformedBbox(Bbox([[0.1, 0.2], [0.3, 0.25]]),
                        ax.figure.transFigure))
    eye_im.set_data(np.identity(10))
    eye_im.set_clip_on(False)
    ax.add_artist(eye_im)
444
445
def test_get_window_extent_for_AxisImage():
    """get_window_extent reports the image bbox in display pixels.

    The figure is 10 in x 100 dpi = 1000 px square, the axes fill it, and
    the data limits are 0..1, so data coordinates scale by 1000 to pixels.
    """
    pixels = np.array([[0.25, 0.75, 1.0, 0.75], [0.1, 0.65, 0.5, 0.4],
                       [0.6, 0.3, 0.0, 0.2], [0.7, 0.9, 0.4, 0.6]])
    fig, ax = plt.subplots(figsize=(10, 10), dpi=100)
    ax.set_position([0, 0, 1, 1])
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    image = ax.imshow(pixels, extent=[0.4, 0.7, 0.2, 0.9],
                      interpolation='nearest')

    fig.canvas.draw()
    bbox = image.get_window_extent(fig.canvas.renderer)

    assert_array_equal(bbox.get_points(), [[400, 200], [700, 900]])
464
465
@image_comparison(baseline_images=['zoom_and_clip_upper_origin'],
                  remove_text=True,
                  extensions=['png'],
                  style='mpl20')
def test_zoom_and_clip_upper_origin():
    """Zoom onto the corner around element [0, 0] of a default-origin image."""
    fig, ax = plt.subplots()
    ax.imshow(np.arange(100).reshape((10, 10)))
    # y runs from 2.0 at the bottom to -0.5 at the top.
    ax.set_ylim(2.0, -0.5)
    ax.set_xlim(-0.5, 2.0)
478
479
def test_nonuniformimage_setcmap():
    """set_cmap accepts a colormap name on a NonUniformImage."""
    NonUniformImage(plt.gca()).set_cmap('Blues')
484
485
def test_nonuniformimage_setnorm():
    """set_norm accepts a Normalize instance on a NonUniformImage."""
    NonUniformImage(plt.gca()).set_norm(plt.Normalize())
490
491
def test_jpeg_2d():
    """Smoke test: imshow accepts a mode-'L' (8-bit grayscale) Pillow image."""
    Image = pytest.importorskip('PIL.Image')
    # Every row holds the same 0..255 ramp (truncated to uint8).
    gradient = np.ones((10, 10), dtype='uint8')
    gradient[:] = np.linspace(0.0, 1.0, 10) * 255
    gray = Image.new('L', (10, 10))
    gray.putdata(gradient.flatten())
    fig, ax = plt.subplots()
    ax.imshow(gray)
502
503
def test_jpeg_alpha():
    """Saving RGBA data as JPEG blends it against savefig.facecolor."""
    Image = pytest.importorskip('PIL.Image')

    plt.figure(figsize=(1, 1), dpi=300)
    # All-black image whose alpha ramps from 0 (left) to 1 (right).
    rgba = np.zeros((300, 300, 4), dtype=float)
    rgba[..., 3] = np.linspace(0.0, 1.0, 300)
    plt.figimage(rgba)

    buff = io.BytesIO()
    with rc_context({'savefig.facecolor': 'red'}):
        plt.savefig(buff, transparent=True, format='jpg', dpi=300)
    buff.seek(0)
    image = Image.open(buff)

    # If alpha were ignored, the output would be a single color (all
    # black).  Blended against the red background, the ramp must instead
    # yield a number of distinct colors inside the range asserted here.
    num_colors = len(image.getcolors(256))
    assert 175 <= num_colors <= 185
    # Column 0 has alpha 0, so the corner shows the red background.
    corner_pixel = image.getpixel((0, 0))
    assert corner_pixel == (254, 0, 0)
529
530
def test_nonuniformimage_setdata():
    """NonUniformImage.set_data keeps copies of the arrays it is given."""
    im = NonUniformImage(plt.gca())
    xs = np.arange(3, dtype=float)
    ys = np.arange(4, dtype=float)
    zs = np.arange(12, dtype=float).reshape((4, 3))
    im.set_data(xs, ys, zs)
    # Mutating the caller's arrays afterwards must not reach the image.
    xs[0] = ys[0] = zs[0, 0] = 9.9
    assert im._A[0, 0] == im._Ax[0] == im._Ay[0] == 0, 'value changed'
540
541
def test_axesimage_setdata():
    """AxesImage.set_data keeps a copy of the array it is given."""
    im = AxesImage(plt.gca())
    source = np.arange(12, dtype=float).reshape((4, 3))
    im.set_data(source)
    # Mutating the caller's array afterwards must not reach the image.
    source[0, 0] = 9.9
    assert im._A[0, 0] == 0, 'value changed'
549
550
def test_figureimage_setdata():
    """FigureImage.set_data keeps a copy of the array it is given."""
    im = FigureImage(plt.gcf())
    source = np.arange(12, dtype=float).reshape((4, 3))
    im.set_data(source)
    # Mutating the caller's array afterwards must not reach the image.
    source[0, 0] = 9.9
    assert im._A[0, 0] == 0, 'value changed'
558
559
def test_pcolorimage_setdata():
    """PcolorImage.set_data keeps copies of the arrays it is given."""
    im = PcolorImage(plt.gca())
    xs = np.arange(3, dtype=float)
    ys = np.arange(4, dtype=float)
    zs = np.arange(6, dtype=float).reshape((3, 2))
    im.set_data(xs, ys, zs)
    # Mutating the caller's arrays afterwards must not reach the image.
    xs[0] = ys[0] = zs[0, 0] = 9.9
    assert im._A[0, 0] == im._Ax[0] == im._Ay[0] == 0, 'value changed'
569
570
def test_pcolorimage_extent():
    """The artist returned by hist2d spans the outermost bin edges."""
    results = plt.hist2d([1, 2, 3], [3, 5, 6],
                         bins=[[0, 3, 7], [1, 2, 3]])
    image = results[-1]
    assert image.get_extent() == (0, 7, 1, 3)
575
576
def test_minimized_rasterized():
    """Rasterized colorbar content is only as thick as the colorbar (#5814).

    The original bug only shows in PostScript, but it is easiest to detect
    by saving SVG and checking that the embedded raster images of the two
    identical colorbars have matching widths.  Without the fix, the raster
    extends over other parts of the figure and is far wider.
    """
    from xml.etree import ElementTree

    np.random.seed(0)
    data = np.random.rand(10, 10)

    fig, ax = plt.subplots(1, 2)
    p1 = ax[0].pcolormesh(data)
    p2 = ax[1].pcolormesh(data)

    plt.colorbar(p1, ax=ax[0])
    plt.colorbar(p2, ax=ax[1])

    buff = io.BytesIO()
    plt.savefig(buff, format='svg')

    buff = io.BytesIO(buff.getvalue())
    tree = ElementTree.parse(buff)
    # The SVG root declares the SVG XML namespace as its default, so the
    # tag must be namespace-qualified; a bare 'image' never matches.
    # Attributes are read with .get(), since Element does not support
    # item access by attribute name.
    svg_image_tag = '{http://www.w3.org/2000/svg}image'
    widths = [float(image.get('width')) for image in tree.iter(svg_image_tag)]
    # Guard against the check silently passing on zero images.
    assert widths, 'no embedded raster images found in the SVG output'
    # A small tolerance absorbs one-raster-pixel differences caused by the
    # colorbars' different sub-pixel offsets; the #5814 bug is a gross
    # size mismatch.
    assert np.allclose(widths, widths[0], rtol=0.1), widths
608
609
@pytest.mark.network
def test_load_from_url():
    """imread accepts an open URL response (a file-like object)."""
    from contextlib import closing

    url = "http://matplotlib.org/_static/logo_sidebar_horiz.png"
    # closing() releases the connection even if imread raises.  It also
    # works on Python 2, where the response object is not a context
    # manager.
    with closing(six.moves.urllib.request.urlopen(url)) as req:
        plt.imread(req)
615
616
@image_comparison(baseline_images=['log_scale_image'],
                  remove_text=True)
# The recwarn fixture captures a warning in image_comparison.
def test_log_scale_image(recwarn):
    """A striped image shown on a log-scaled y axis."""
    stripes = np.zeros((10, 10))
    stripes[::2] = 1  # every other row set to 1

    fig, ax = plt.subplots()
    ax.imshow(stripes, extent=[1, 100, 1, 100], cmap='viridis',
              vmax=1, vmin=-1)
    ax.set_yscale('log')
628
629
@image_comparison(baseline_images=['rotate_image'],
                  remove_text=True)
def test_rotate_image():
    """An image drawn through a 30-degree-rotated data transform.

    The nominal extent is outlined in red with the same transform, so the
    image should sit exactly inside the dashed outline.
    """
    delta = 0.25
    grid = np.arange(-3.0, 3.0, delta)
    X, Y = np.meshgrid(grid, grid)
    Z1 = np.exp(-(X**2 + Y**2) / 2) / (2 * np.pi)
    Z2 = (np.exp(-(((X - 1) / 1.5)**2 + ((Y - 1) / 0.5)**2) / 2) /
          (2 * np.pi * 0.5 * 1.5))
    Z = Z2 - Z1  # difference of Gaussians

    fig, ax = plt.subplots(1, 1)
    image = ax.imshow(Z, interpolation='none', cmap='viridis',
                      origin='lower',
                      extent=[-2, 4, -3, 2], clip_on=True)

    rotated = Affine2D().rotate_deg(30) + ax.transData
    image.set_transform(rotated)

    left, right, bottom, top = image.get_extent()
    outline_x = [left, right, right, left, left]
    outline_y = [bottom, bottom, top, top, bottom]
    ax.plot(outline_x, outline_y, "r--", lw=3, transform=rotated)

    ax.set_xlim(2, 5)
    ax.set_ylim(0, 4)
657
658
def test_image_preserve_size():
    """imsave writes exactly one pixel per array element."""
    blank = np.zeros((481, 321))
    buff = io.BytesIO()
    plt.imsave(buff, blank, format="png")
    buff.seek(0)
    assert plt.imread(buff).shape[:2] == blank.shape
669
670
def test_image_preserve_size2():
    """A 7x7 image filling a 7-inch figure saved at dpi=1 maps 1:1 to pixels."""
    n = 7
    data = np.identity(n, float)

    fig = plt.figure(figsize=(n, n), frameon=False)
    ax = plt.Axes(fig, [0.0, 0.0, 1.0, 1.0])
    ax.set_axis_off()
    fig.add_axes(ax)
    ax.imshow(data, interpolation='nearest', origin='lower', aspect='auto')

    buff = io.BytesIO()
    fig.savefig(buff, dpi=1)
    buff.seek(0)
    saved = plt.imread(buff)

    assert saved.shape == (7, 7, 4)
    # origin='lower' puts row 0 at the bottom, so the saved diagonal comes
    # back flipped vertically.
    red_mask = np.asarray(saved[:, :, 0], bool)
    assert_array_equal(red_mask, np.identity(n, bool)[::-1])
691
692
@image_comparison(baseline_images=['mask_image_over_under'],
                  remove_text=True, extensions=['png'])
def test_mask_image_over_under():
    """Over, under and bad colors on a masked image, with two norms."""
    delta = 0.025
    grid = np.arange(-3.0, 3.0, delta)
    X, Y = np.meshgrid(grid, grid)
    Z1 = np.exp(-(X**2 + Y**2) / 2) / (2 * np.pi)
    Z2 = (np.exp(-(((X - 1) / 1.5)**2 + ((Y - 1) / 0.5)**2) / 2) /
          (2 * np.pi * 0.5 * 1.5))
    Z = 10*(Z2 - Z1)  # difference of Gaussians

    palette = copy(plt.cm.gray)
    palette.set_over('r', 1.0)
    palette.set_under('g', 1.0)
    palette.set_bad('b', 1.0)
    Zm = ma.masked_where(Z > 1.2, Z)

    # Settings shared by both panels.
    shared = dict(cmap=palette, origin='lower', extent=[-3, 3, -3, 3])
    fig, (ax1, ax2) = plt.subplots(1, 2)

    im = ax1.imshow(Zm, interpolation='bilinear',
                    norm=colors.Normalize(vmin=-1.0, vmax=1.0, clip=False),
                    **shared)
    ax1.set_title('Green=low, Red=high, Blue=bad')
    fig.colorbar(im, extend='both', orientation='horizontal',
                 ax=ax1, aspect=10)

    boundaries = [-1, -0.5, -0.2, 0, 0.2, 0.5, 1]
    im = ax2.imshow(Zm, interpolation='nearest',
                    norm=colors.BoundaryNorm(boundaries, ncolors=256,
                                             clip=False),
                    **shared)
    ax2.set_title('With BoundaryNorm')
    fig.colorbar(im, extend='both', spacing='proportional',
                 orientation='horizontal', ax=ax2, aspect=10)
726
727
@image_comparison(baseline_images=['mask_image'],
                  remove_text=True)
def test_mask_image():
    """Hide one pixel via NaN (left) and via a masked array (right)."""
    fig, (ax1, ax2) = plt.subplots(1, 2)

    with_nan = np.ones((5, 5))
    with_nan[1:2, 1:2] = np.nan
    ax1.imshow(with_nan, interpolation='nearest')

    hidden = np.zeros((5, 5), dtype=bool)
    hidden[1:2, 1:2] = True
    masked = np.ma.masked_array(np.ones((5, 5), dtype=np.uint16), hidden)
    ax2.imshow(masked, interpolation='nearest')
745
746
@image_comparison(baseline_images=['imshow_endianess'],
                  remove_text=True, extensions=['png'])
def test_imshow_endianess():
    """Show the same data as little-endian (left) and big-endian (right)."""
    x = np.arange(10)
    X, Y = np.meshgrid(x, x)
    Z = ((X-5)**2 + (Y-5)**2)**0.5

    fig, (ax1, ax2) = plt.subplots(1, 2)
    for ax, byte_order in ((ax1, '<f8'), (ax2, '>f8')):
        ax.imshow(Z.astype(byte_order), origin="lower",
                  interpolation='nearest', cmap='viridis')
761
762
@image_comparison(baseline_images=['imshow_masked_interpolation'],
                  remove_text=True, style='mpl20')
def test_imshow_masked_interpolation():
    """Every interpolation on data with under, over, inf and masked cells."""
    palette = copy(plt.get_cmap('viridis'))
    palette.set_over('r')
    palette.set_under('b')
    palette.set_bad('k')

    size = 20
    norm = colors.Normalize(vmin=0, vmax=size*size-1)

    data = np.arange(size*size, dtype='float').reshape(size, size)
    data[5, 5] = -1  # below vmin
    # A huge spike causes strong ringing in the higher-order interpolations.
    data[15, 5] = 1e5
    data[15, 15] = np.inf

    mask = np.zeros(data.shape, dtype=bool)
    mask[5, 15] = True
    data = np.ma.masked_array(data, mask)

    fig, ax_grid = plt.subplots(3, 6)

    for interp, ax in zip(sorted(mimage._interpd_), ax_grid.ravel()):
        ax.set_title(interp)
        ax.imshow(data, norm=norm, cmap=palette, interpolation=interp)
        ax.axis('off')
798
799
def test_imshow_no_warn_invalid():
    """imshow must not emit warnings for NaN values in the data."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        plt.imshow([[1, 2], [3, np.nan]])
    assert not caught
805
806
@pytest.mark.parametrize(
    'dtype', [np.dtype(s) for s in 'u2 u4 i2 i4 i8 f4 f8'.split()])
def test_imshow_clips_rgb_to_valid_range(dtype):
    """Out-of-range RGB input is clipped to the valid range.

    Float data is clipped to [0, 1] and stays float.  Integer data is
    clipped to [0, 255] and converted to uint8.
    """
    arr = np.arange(300, dtype=dtype).reshape((10, 10, 3))
    if dtype.kind != 'u':
        arr -= 10  # push the first few values below zero
    too_low = arr < 0
    too_high = arr > 255
    is_float = dtype.kind == 'f'
    if is_float:
        arr = arr / 255

    _, ax = plt.subplots()
    out = ax.imshow(arr).get_array()

    assert (out[too_low] == 0).all()
    if is_float:
        assert (out[too_high] == 1).all()
        assert out.dtype.kind == 'f'
    else:
        assert (out[too_high] == 255).all()
        assert out.dtype == np.uint8
826
827
@image_comparison(baseline_images=['imshow_flatfield'],
                  remove_text=True, style='mpl20',
                  extensions=['png'])
def test_imshow_flatfield():
    """A constant image with color limits set around its value."""
    fig, ax = plt.subplots()
    ax.imshow(np.ones((5, 5))).set_clim(.5, 1.5)
835
836
@image_comparison(baseline_images=['imshow_bignumbers'],
                  remove_text=True, style='mpl20',
                  extensions=['png'])
def test_imshow_bignumbers():
    # One huge value in an integer array must not destroy the resolution
    # of the small values once clim narrows the range to 0..5.
    fig, ax = plt.subplots()
    img = np.array([[1, 2, 1e12], [3, 1, 4]], dtype=np.uint64)
    ax.imshow(img).set_clim(0, 5)
847
848
@image_comparison(baseline_images=['imshow_bignumbers_real'],
                  remove_text=True, style='mpl20',
                  extensions=['png'])
def test_imshow_bignumbers_real():
    # One huge value in a float array must not destroy the resolution of
    # the small values once clim narrows the range to 0..5.
    fig, ax = plt.subplots()
    img = np.array([[2., 1., 1.e22], [4., 1., 3.]])
    ax.imshow(img).set_clim(0, 5)
859
860
@pytest.mark.parametrize(
    "make_norm",
    [colors.Normalize,
     colors.LogNorm,
     lambda: colors.SymLogNorm(1),
     lambda: colors.PowerNorm(1)])
def test_empty_imshow(make_norm):
    """An empty image can be drawn, but make_image raises RuntimeError."""
    fig, ax = plt.subplots()
    with warnings.catch_warnings():
        # Empty data gives identical x limits; that warning is expected.
        warnings.filterwarnings(
            "ignore", "Attempting to set identical left==right")
        empty = ax.imshow([[]], norm=make_norm())
    empty.set_extent([-5, 5, -5, 5])
    fig.canvas.draw()

    with pytest.raises(RuntimeError):
        empty.make_image(fig._cachedRenderer)
878
879
def test_imshow_float128():
    """imshow accepts extended-precision (np.longdouble) data."""
    _, ax = plt.subplots()
    data = np.zeros((3, 3), dtype=np.longdouble)
    ax.imshow(data)
883
884
def test_imshow_bool():
    """imshow accepts boolean arrays."""
    _, ax = plt.subplots()
    checker = np.array([[True, False], [False, True]], dtype=bool)
    ax.imshow(checker)
888
889
def test_imshow_deprecated_interd_warn():
    """Each deprecated interpolation-table attribute warns exactly once.

    ``simplefilter("always")`` inside the recording block makes the count
    independent of outer warning filters and of the per-location
    "already warned" registry.  Either could otherwise hide the warning or
    turn it into an error.
    """
    im = plt.imshow([[1, 2], [3, np.nan]])
    # 'iterpnames' (sic) is the real name of the deprecated attribute.
    for k in ('_interpd', '_interpdr', 'iterpnames'):
        with warnings.catch_warnings(record=True) as warns:
            warnings.simplefilter("always")
            getattr(im, k)
        assert len(warns) == 1
896
897
def test_full_invalid():
    """An image made entirely of NaNs can be drawn."""
    invalid = np.full((10, 10), np.nan)

    fig, ax = plt.subplots()
    ax.imshow(invalid)
    fig.canvas.draw()
906
907
@pytest.mark.parametrize("fmt,counted",
                         [("ps", b" colorimage"), ("svg", b"<image")])
@pytest.mark.parametrize("composite_image,count", [(True, 1), (False, 2)])
def test_composite(fmt, counted, composite_image, count):
    """Images on one axes are merged only when image.composite_image is set.

    The number of image operators (*counted*) in the vector output tells
    whether the two images were combined into one (*count* == 1) or kept
    separate (*count* == 2).  The rcParam is applied with rc_context, so it
    is restored afterwards and cannot leak into other tests.
    """
    X, Y = np.meshgrid(np.arange(-5, 5, 1), np.arange(-5, 5, 1))
    Z = np.sin(Y ** 2)

    fig, ax = plt.subplots()
    ax.set_xlim(0, 3)
    ax.imshow(Z, extent=[0, 1, 0, 1])
    ax.imshow(Z[::-1], extent=[2, 3, 0, 1])
    buf = io.BytesIO()
    with rc_context({'image.composite_image': composite_image}):
        fig.savefig(buf, format=fmt)
    assert buf.getvalue().count(counted) == count
925
926
def test_relim():
    """relim followed by autoscale fits the limits to the image extent."""
    fig, ax = plt.subplots()
    ax.imshow([[0]], extent=(0, 1, 0, 1))
    ax.relim()
    ax.autoscale()
    assert ax.get_xlim() == (0, 1)
    assert ax.get_ylim() == (0, 1)
933