1# Copyright (c) 2016,2017 MetPy Developers.
2# Distributed under the terms of the BSD 3-Clause License.
3# SPDX-License-Identifier: BSD-3-Clause
4"""Tests for the `station_plot` module."""
5
6import matplotlib
7import matplotlib.pyplot as plt
8import numpy as np
9import pandas as pd
10import pytest
11
12from metpy.plots import (current_weather, high_clouds, nws_layout, simple_layout, sky_cover,
13                         StationPlot, StationPlotLayout)
14from metpy.units import units
15
# Major.minor matplotlib version string (e.g. '3.0', '3.10'), used below to pick
# version-specific image-comparison tolerances. Slicing the raw version string
# (``__version__[:3]``) would turn '3.10.x' into '3.1', so split on the dots instead.
MPL_VERSION = '.'.join(matplotlib.__version__.split('.')[:2])
17
18
@pytest.mark.mpl_image_compare(tolerance=2.444, savefig_kwargs={'dpi': 300}, remove_text=True)
def test_stationplot_api():
    """Exercise each basic StationPlot plotting method on two stations."""
    fig = plt.figure(figsize=(9, 9))
    ax = fig.add_subplot(1, 1, 1)

    # Two stations at fixed positions
    xs, ys = np.array([1, 5]), np.array([2, 4])
    station_plot = StationPlot(ax, xs, ys, fontsize=16)

    # One item of each kind: barb, text, unit-carrying value, symbol
    station_plot.plot_barb([20, 0], [0, -50])
    station_plot.plot_text('E', ['KOKC', 'ICT'], color='blue')
    station_plot.plot_parameter('NW', [10.5, 15] * units.degC, color='red')
    station_plot.plot_symbol('S', [5, 7], high_clouds, color='green')

    station_plot.ax.set_xlim(0, 6)
    station_plot.ax.set_ylim(0, 6)

    return fig
39
40
@pytest.mark.mpl_image_compare(tolerance=1.976, savefig_kwargs={'dpi': 300}, remove_text=True)
def test_stationplot_clipping():
    """Test that clipping can be enabled as a default parameter.

    ``clip_on=True`` is passed once to the StationPlot constructor, and the axes limits
    are then tightened around the stations so parts of the plotted items fall outside
    the axes and should be clipped.
    """
    fig = plt.figure(figsize=(9, 9))

    # testing data
    x = np.array([1, 5])
    y = np.array([2, 4])

    # Make the plot
    sp = StationPlot(fig.add_subplot(1, 1, 1), x, y, fontsize=16, clip_on=True)
    sp.plot_barb([20, 0], [0, -50])
    sp.plot_text('E', ['KOKC', 'ICT'], color='blue')
    sp.plot_parameter('NW', [10.5, 15] * units.degC, color='red')
    sp.plot_symbol('S', [5, 7], high_clouds, color='green')

    # Tight limits so that the items around the stations extend past the axes edges
    sp.ax.set_xlim(1, 5)
    sp.ax.set_ylim(1.75, 4.25)

    return fig
61
62
@pytest.mark.mpl_image_compare(tolerance=0.25, savefig_kwargs={'dpi': 300}, remove_text=True)
def test_station_plot_replace():
    """Test that plotting again at an occupied location replaces the earlier item."""
    fig = plt.figure(figsize=(3, 3))
    ax = fig.add_subplot(1, 1, 1)

    # A single station at (1, 1)
    station = StationPlot(ax, np.array([1]), np.array([1]), fontsize=16)

    # The second barb should replace the first
    station.plot_barb([20], [0])
    station.plot_barb([5], [0])

    # The second NW value should replace the first
    station.plot_parameter('NW', [10.5], color='red')
    station.plot_parameter('NW', [20], color='blue')

    station.ax.set_xlim(-3, 3)
    station.ax.set_ylim(-3, 3)

    return fig
83
84
@pytest.mark.mpl_image_compare(tolerance=0.25, savefig_kwargs={'dpi': 300}, remove_text=True)
def test_station_plot_locations():
    """Test that text is placed at every supported named location.

    Each location name is drawn as text at that same location around a single station,
    so the image shows where every named location lands.
    """
    fig = plt.figure(figsize=(3, 3))

    locations = ['C', 'N', 'NE', 'E', 'SE', 'S', 'SW', 'W', 'NW', 'N2', 'NNE', 'ENE', 'E2',
                 'ESE', 'SSE', 'S2', 'SSW', 'WSW', 'W2', 'WNW', 'NNW']
    x_pos = np.array([0])
    y_pos = np.array([0])

    # Make the plot; the larger spacing keeps neighbouring labels from overlapping
    sp = StationPlot(fig.add_subplot(1, 1, 1), x_pos, y_pos, fontsize=8, spacing=24)
    for loc in locations:
        sp.plot_text(loc, [loc])

    sp.ax.set_xlim(-2, 2)
    sp.ax.set_ylim(-2, 2)

    return fig
104
105
@pytest.mark.mpl_image_compare(tolerance=0.00413, savefig_kwargs={'dpi': 300},
                               remove_text=True)
def test_stationlayout_api():
    """Test the StationPlotLayout API.

    Builds a layout with a barb, a unit-converted value, a symbol, and text at a tuple
    offset, then plots it for two stations. The layout also requests a ``dewpt`` field
    that is absent from the data; that entry should be skipped rather than raise.
    """
    fig = plt.figure(figsize=(9, 9))

    # testing data
    x = np.array([1, 5])
    y = np.array([2, 4])
    data = {'temp': np.array([33., 212.]) * units.degF, 'u': np.array([2, 0]) * units.knots,
            'v': np.array([0, 5]) * units.knots, 'stid': ['KDEN', 'KSHV'], 'cover': [3, 8]}

    # Set up the layout
    layout = StationPlotLayout()
    layout.add_barb('u', 'v', units='knots')
    layout.add_value('NW', 'temp', fmt='0.1f', units=units.degC, color='darkred')
    layout.add_symbol('C', 'cover', sky_cover, color='magenta')
    layout.add_text((0, 2), 'stid', color='darkgrey')
    layout.add_value('NE', 'dewpt', color='green')  # 'dewpt' is not in data: ignored

    # Make the plot
    sp = StationPlot(fig.add_subplot(1, 1, 1), x, y, fontsize=12)
    layout.plot(sp, data)

    sp.ax.set_xlim(0, 6)
    sp.ax.set_ylim(0, 6)

    return fig
134
135
def test_station_layout_odd_data():
    """Test more corner cases with data passed in.

    The layout asks for wind components ``u``/``v`` that are absent from the data and
    for a temperature supplied as a plain list without units. Plotting must not raise,
    and because the wind data is missing no barbs should be drawn.
    """
    fig = plt.figure(figsize=(9, 9))

    # Set up test layout
    layout = StationPlotLayout()
    layout.add_barb('u', 'v')
    layout.add_value('W', 'temperature', units='degF')

    # Now only use data without wind and no units
    data = {'temperature': [25.]}

    # Make the plot
    sp = StationPlot(fig.add_subplot(1, 1, 1), [1], [2], fontsize=12)
    layout.plot(sp, data)

    # The barb entry had no data, so it should have been skipped entirely
    assert sp.barbs is None
152
153
def test_station_layout_replace():
    """Test that adding a second entry at an occupied location overwrites the first."""
    station_layout = StationPlotLayout()
    station_layout.add_text('E', 'temperature')
    station_layout.add_value('E', 'dewpoint')

    assert 'E' in station_layout
    entry = station_layout['E']
    assert entry[0] is StationPlotLayout.PlotTypes.value
    assert entry[1] == 'dewpoint'
162
163
def test_station_layout_names():
    """Test that ``names()`` reports every data field referenced by the layout."""
    lay = StationPlotLayout()
    lay.add_barb('u', 'v')
    lay.add_text('E', 'stid')
    lay.add_value('W', 'temp')
    lay.add_symbol('C', 'cover', lambda x: x)

    expected = ['cover', 'stid', 'temp', 'u', 'v']
    assert sorted(lay.names()) == expected
172
173
@pytest.mark.mpl_image_compare(tolerance=0, savefig_kwargs={'dpi': 300}, remove_text=True)
def test_simple_layout():
    """Test plotting two stations with MetPy's predefined ``simple_layout``."""
    fig = plt.figure(figsize=(9, 9))
    ax = fig.add_subplot(1, 1, 1)

    # Observations for two stations; 'unused' is presumably not referenced by the layout
    obs = {
        'air_temperature': np.array([33., 212.]) * units.degF,
        'dew_point_temperature': np.array([28., 80.]) * units.degF,
        'air_pressure_at_sea_level': np.array([29.92, 28.00]) * units.inHg,
        'eastward_wind': np.array([2, 0]) * units.knots,
        'northward_wind': np.array([0, 5]) * units.knots,
        'cloud_coverage': [3, 8],
        'current_wx1_symbol': [65, 75],
        'unused': [1, 2],
    }

    station_plot = StationPlot(ax, np.array([1, 5]), np.array([2, 4]), fontsize=12)
    simple_layout.plot(station_plot, obs)

    station_plot.ax.set_xlim(0, 6)
    station_plot.ax.set_ylim(0, 6)

    return fig
197
198
@pytest.mark.mpl_image_compare(tolerance=0.1848, savefig_kwargs={'dpi': 300}, remove_text=True)
def test_nws_layout():
    """Test plotting a single station with MetPy's predefined ``nws_layout``."""
    fig = plt.figure(figsize=(3, 3))
    ax = fig.add_subplot(1, 1, 1)

    # One station's worth of observations
    obs = {
        'air_temperature': np.array([77]) * units.degF,
        'dew_point_temperature': np.array([71]) * units.degF,
        'air_pressure_at_sea_level': np.array([999.8]) * units('mbar'),
        'eastward_wind': np.array([15.]) * units.knots,
        'northward_wind': np.array([15.]) * units.knots,
        'cloud_coverage': [7],
        'current_wx1_symbol': [80],
        'high_cloud_type': [1],
        'medium_cloud_type': [3],
        'low_cloud_type': [2],
        'visibility_in_air': np.array([5.]) * units.mile,
        'tendency_of_air_pressure': np.array([-0.3]) * units('mbar'),
        'tendency_of_air_pressure_symbol': [8],
    }

    station_plot = StationPlot(ax, np.array([1]), np.array([2]), fontsize=12, spacing=16)
    nws_layout.plot(station_plot, obs)

    station_plot.ax.set_xlim(0, 3)
    station_plot.ax.set_ylim(0, 3)

    return fig
225
226
@pytest.mark.mpl_image_compare(tolerance=1.05, remove_text=True)
def test_plot_text_fontsize():
    """Test that a per-call ``fontsize`` in plot_text overrides the StationPlot default."""
    fig = plt.figure(figsize=(3, 3))
    # Attach the axes to this figure explicitly instead of relying on pyplot's
    # implicit "current figure" state
    ax = fig.add_subplot(1, 1, 1)

    # testing data
    x = np.array([1])
    y = np.array([2])

    # Make the plot: default fontsize 36, overridden to 24 and 4 for the two labels
    sp = StationPlot(ax, x, y, fontsize=36)
    sp.plot_text('NW', ['72'], fontsize=24)
    sp.plot_text('SW', ['60'], fontsize=4)

    sp.ax.set_xlim(0, 3)
    sp.ax.set_ylim(0, 3)

    return fig
246
247
@pytest.mark.mpl_image_compare(tolerance=1.05, remove_text=True)
def test_plot_symbol_fontsize():
    """Test changing fontsize in plotting of symbols."""
    fig = plt.figure(figsize=(3, 3))
    # Attach the axes to this figure explicitly instead of relying on pyplot's
    # implicit "current figure" state
    ax = fig.add_subplot(1, 1, 1)

    # Default fontsize 8 on the east symbol; the west symbol overrides it with 100
    sp = StationPlot(ax, [0], [0], fontsize=8, spacing=32)
    sp.plot_symbol('E', [92], current_weather)
    sp.plot_symbol('W', [96], current_weather, fontsize=100)

    return fig
259
260
def test_layout_str():
    """Test the string representation of a layout holding one entry of each kind."""
    lay = StationPlotLayout()
    lay.add_barb('u', 'v')
    lay.add_text('E', 'stid')
    lay.add_value('W', 'temp')
    lay.add_symbol('C', 'cover', lambda x: x)

    expected = ('{C: (symbol, cover, ...), E: (text, stid, ...), '
                "W: (value, temp, ...), barb: (barb, ('u', 'v'), ...)}")
    assert str(lay) == expected
270
271
@pytest.fixture
def wind_plot():
    """Create southerly wind test data on a regular 5x5 grid.

    Returns
    -------
    tuple of ndarray
        ``(u, v, x, y)``: zero u component, constant 10 v component, and the grid's
        x (longitude-like, -120 to -60) and y (latitude-like, 25 to 50) coordinates.
    """
    lons, lats = np.meshgrid(np.linspace(-120, -60, 5), np.linspace(25, 50, 5))
    v_comp = np.full((5, 5), 10, dtype=np.float64)
    u_comp = np.zeros_like(v_comp)
    return u_comp, v_comp, lons, lats
279
280
@pytest.mark.mpl_image_compare(tolerance={'3.0': 0.04231}.get(MPL_VERSION, 0.00434),
                               remove_text=True)
def test_barb_projection(wind_plot, ccrs):
    """Test that barbs are rotated into the map projection (#598)."""
    u_wind, v_wind, lons, lats = wind_plot

    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1, projection=ccrs.LambertConformal())

    # Purely southerly barbs should run parallel to these meridian gridlines
    ax.gridlines(xlocs=[-120, -105, -90, -75, -60], ylocs=np.arange(24, 55, 6))

    stations = StationPlot(ax, lons, lats, transform=ccrs.PlateCarree())
    stations.plot_barb(u_wind, v_wind)

    return fig
295
296
@pytest.mark.mpl_image_compare(tolerance={'3.0': 0.0693}.get(MPL_VERSION, 0.00382),
                               remove_text=True)
def test_arrow_projection(wind_plot, ccrs):
    """Test that arrows are properly projected."""
    u, v, x, y = wind_plot

    # Plot and check arrows (they should align with grid lines)
    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1, projection=ccrs.LambertConformal())
    ax.gridlines(xlocs=[-120, -105, -90, -75, -60], ylocs=np.arange(24, 55, 6))
    sp = StationPlot(ax, x, y, transform=ccrs.PlateCarree())
    sp.plot_arrow(u, v)
    sp.plot_arrow(u, v)  # plot_arrow used twice to hit removal if statement

    return fig
312
313
@pytest.fixture
def wind_projection_list():
    """Create station lat/lon positions as plain lists plus wind components in m/s."""
    speed_units = units('m/s')
    lat = [38.22, 38.18, 38.25]
    lon = [-85.76, -85.86, -85.77]
    u = [1.89778964, -3.83776523, 3.64147732] * speed_units
    v = [1.93480072, 1.31000184, 1.36075552] * speed_units
    return lat, lon, u, v
322
323
def test_barb_projection_list(wind_projection_list):
    """Test that barbs are plotted when station lat/lon are given as plain lists."""
    lats, lons, u_wind, v_wind = wind_projection_list

    ax = plt.figure().add_subplot(1, 1, 1)
    station_plot = StationPlot(ax, lons, lats)
    station_plot.plot_barb(u_wind, v_wind)

    assert station_plot.barbs
333
334
def test_arrow_projection_list(wind_projection_list):
    """Test that arrows are plotted when station lat/lon are given as plain lists."""
    lats, lons, u_wind, v_wind = wind_projection_list

    ax = plt.figure().add_subplot(1, 1, 1)
    station_plot = StationPlot(ax, lons, lats)
    station_plot.plot_arrow(u_wind, v_wind)

    assert station_plot.arrows
344
345
@pytest.fixture
def barbs_units():
    """Create one station at the origin with equal u and v wind components in m/s."""
    speed = 3.63767155210412
    u_wind, v_wind = (np.array([speed]) * units('m/s') for _ in range(2))
    return np.array([0]), np.array([0]), u_wind, v_wind
354
355
@pytest.mark.mpl_image_compare(tolerance=0.0048, remove_text=True)
def test_barb_unit_conversion(barbs_units):
    """Test that ``plot_units`` converts barb wind data at plot time (#737)."""
    *position, u_wind, v_wind = barbs_units

    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)
    StationPlot(ax, *position).plot_barb(u_wind, v_wind, plot_units='knots')

    ax.set_xlim(-5, 5)
    ax.set_ylim(-5, 5)
    return fig
369
370
@pytest.mark.mpl_image_compare(tolerance=0.0048, remove_text=True)
def test_arrow_unit_conversion(barbs_units):
    """Test that ``plot_units`` converts arrow wind data at plot time (#737)."""
    *position, u_wind, v_wind = barbs_units

    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)
    StationPlot(ax, *position).plot_arrow(u_wind, v_wind, plot_units='knots')

    ax.set_xlim(-5, 5)
    ax.set_ylim(-5, 5)
    return fig
384
385
@pytest.mark.mpl_image_compare(tolerance=0.0048, remove_text=True)
def test_barb_no_default_unit_conversion():
    """Test that barb wind data is plotted in its own units when none are requested (#737)."""
    speed = 3.63767155210412
    u_wind = np.array([speed]) * units('m/s')
    v_wind = np.array([speed]) * units('m/s')

    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)
    StationPlot(ax, np.array([0]), np.array([0])).plot_barb(u_wind, v_wind)

    ax.set_xlim(-5, 5)
    ax.set_ylim(-5, 5)
    return fig
402
403
@pytest.mark.parametrize('u,v', [(np.array([3]) * units('m/s'), np.array([3])),
                                 (np.array([3]), np.array([3]) * units('m/s'))])
def test_barb_unit_conversion_exception(u, v):
    """Test that an error is raised if unit conversion is requested on unitless data.

    In each parametrized case only one of the two wind components carries units.
    """
    x_pos = np.array([0])
    y_pos = np.array([0])

    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)
    stnplot = StationPlot(ax, x_pos, y_pos)
    with pytest.raises(ValueError):
        stnplot.plot_barb(u, v, plot_units='knots')
416
417
@pytest.mark.mpl_image_compare(tolerance=0.021, savefig_kwargs={'dpi': 300}, remove_text=True)
def test_symbol_pandas_timeseries():
    """Test the usage of Pandas DatetimeIndex as a valid `x` input into StationPlot."""
    # Let matplotlib accept pandas datetime values on the x axis
    pd.plotting.register_matplotlib_converters()
    rng = pd.date_range('12/1/2017', periods=5, freq='D')
    sc = [1, 2, 3, 4, 5]
    ts = pd.Series(sc, index=rng)
    fig, ax = plt.subplots()
    y = np.ones(len(ts.index))
    stationplot = StationPlot(ax, ts.index, y, fontsize=12)
    stationplot.plot_symbol('C', ts, sky_cover)
    ax.xaxis.set_major_locator(matplotlib.dates.DayLocator())
    # '%-d' (no zero padding) is a glibc extension that some platform strftime
    # implementations (e.g. Windows) reject; '%d' is portable
    ax.xaxis.set_major_formatter(matplotlib.dates.DateFormatter('%d'))

    return fig
433
434
@pytest.mark.mpl_image_compare(tolerance=2.444, savefig_kwargs={'dpi': 300}, remove_text=True)
def test_stationplot_unit_conversion():
    """Test that ``plot_units`` converts parameter values at plot time.

    Same plot as the basic StationPlot API test, except the Celsius temperatures are
    drawn converted to Fahrenheit.
    """
    fig = plt.figure(figsize=(9, 9))

    # testing data
    x = np.array([1, 5])
    y = np.array([2, 4])

    # Make the plot
    sp = StationPlot(fig.add_subplot(1, 1, 1), x, y, fontsize=16)
    sp.plot_barb([20, 0], [0, -50])
    sp.plot_text('E', ['KOKC', 'ICT'], color='blue')
    sp.plot_parameter('NW', [10.5, 15] * units.degC, plot_units='degF', color='red')
    sp.plot_symbol('S', [5, 7], high_clouds, color='green')

    sp.ax.set_xlim(0, 6)
    sp.ax.set_ylim(0, 6)

    return fig
455
456
def test_scalar_unit_conversion_exception():
    """Test that an error is raised if unit conversion is requested on unitless data."""
    x_pos = np.array([0])
    y_pos = np.array([0])

    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)
    stnplot = StationPlot(ax, x_pos, y_pos)
    # A bare number has no units to convert from
    with pytest.raises(ValueError):
        stnplot.plot_parameter('C', 50, plot_units='degC')
467