1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3"""
4Test suite for the network class.
5
6:copyright:
7    Lion Krischer (krischer@geophysik.uni-muenchen.de), 2013
8:license:
9    GNU Lesser General Public License, Version 3
10    (https://www.gnu.org/copyleft/lesser.html)
11"""
12from __future__ import (absolute_import, division, print_function,
13                        unicode_literals)
14from future.builtins import *  # NOQA
15
16import io
17import os
18import unittest
19import warnings
20
21import numpy as np
22from matplotlib import rcParams
23
24import obspy
25from obspy import UTCDateTime, read_inventory
26from obspy.core.compatibility import mock
27from obspy.core.util import (
28    BASEMAP_VERSION, CARTOPY_VERSION, MATPLOTLIB_VERSION, PROJ4_VERSION)
29from obspy.core.util.testing import ImageComparison
30from obspy.core.inventory import (Channel, Inventory, Network, Response,
31                                  Station)
32
33
class NetworkTestCase(unittest.TestCase):
    """
    Tests for the :class:`~obspy.core.inventory.network.Network` class.
    """
    def setUp(self):
        # Directory with the baseline images used by ImageComparison.
        self.image_dir = os.path.join(os.path.dirname(__file__), 'images')
        # Silence numpy floating point warnings for the duration of each
        # test; the previous settings are restored in tearDown().
        self.nperr = np.geterr()
        np.seterr(all='ignore')

    def tearDown(self):
        np.seterr(**self.nperr)

    def test_get_response(self):
        """
        Tests looking up a channel response by SEED id in a network that
        holds several stations.
        """
        def make_station(code, response):
            # Station at the origin with a single BHZ channel (empty location
            # code) that carries ``response``.
            channel = Channel(code='BHZ',
                              location_code='',
                              latitude=0.0,
                              longitude=0.0,
                              elevation=0.0,
                              depth=0.0,
                              response=response)
            return Station(code=code,
                           latitude=0.0,
                           longitude=0.0,
                           elevation=0.0,
                           channels=[channel])

        response_n1_s1 = Response('RESPN1S1')
        response_n1_s2 = Response('RESPN1S2')
        response_n2_s1 = Response('RESPN2S1')
        # All three stations are attached to network 'N1' - including the
        # one whose station code happens to be 'N2S1'.
        network = Network('N1', stations=[
            make_station('N1S1', response_n1_s1),
            make_station('N1S2', response_n1_s2),
            make_station('N2S1', response_n2_s1)])

        dt = UTCDateTime('2010-01-01T12:00')
        expected = [('N1.N1S1..BHZ', response_n1_s1),
                    ('N1.N1S2..BHZ', response_n1_s2),
                    ('N1.N2S1..BHZ', response_n2_s1)]
        for seed_id, expected_response in expected:
            response = network.get_response(seed_id, dt)
            self.assertEqual(response, expected_response, msg=seed_id)

    def test_get_coordinates(self):
        """
        Test extracting coordinates
        """
        expected = {u'latitude': 47.737166999999999,
                    u'longitude': 12.795714,
                    u'elevation': 860.0,
                    u'local_depth': 0.0}
        # The station itself sits at (0, 0, 0); the expected coordinates are
        # those of the channel.
        channels = [Channel(code='EHZ',
                            location_code='',
                            start_date=UTCDateTime('2007-01-01'),
                            latitude=47.737166999999999,
                            longitude=12.795714,
                            elevation=860.0,
                            depth=0.0)]
        stations = [Station(code='RJOB',
                            latitude=0.0,
                            longitude=0.0,
                            elevation=0.0,
                            channels=channels)]
        network = Network('BW', stations=stations)
        # 1
        coordinates = network.get_coordinates('BW.RJOB..EHZ',
                                              UTCDateTime('2010-01-01T12:00'))
        self.assertEqual(sorted(coordinates.items()), sorted(expected.items()))
        # 2 - without datetime
        coordinates = network.get_coordinates('BW.RJOB..EHZ')
        self.assertEqual(sorted(coordinates.items()), sorted(expected.items()))
        # 3 - unknown SEED ID should raise exception
        self.assertRaises(Exception, network.get_coordinates, 'BW.RJOB..XXX')

    def test_response_plot(self):
        """
        Tests the response plot.
        """
        # Relax the image comparison tolerance for matplotlib versions 1.4.0
        # up to and including 1.5.0, which render this plot differently.
        # See https://github.com/matplotlib/matplotlib/issues/4012
        reltol = 1.0
        if [1, 4, 0] <= MATPLOTLIB_VERSION <= [1, 5, 0]:
            reltol = 2.0

        net = read_inventory()[0]
        t = UTCDateTime(2008, 7, 1)
        with warnings.catch_warnings(record=True):
            warnings.simplefilter("ignore")
            with ImageComparison(self.image_dir, "network_response.png",
                                 reltol=reltol) as ic:
                rcParams['savefig.dpi'] = 72
                net.plot_response(0.002, output="DISP", channel="B*E",
                                  time=t, outfile=ic.name)

    def test_response_plot_epoch_times_in_label(self):
        """
        Tests response plot with epoch times in labels switched on.
        """
        import matplotlib.pyplot as plt
        net = read_inventory().select(station='RJOB', channel='EHZ')[0]
        with warnings.catch_warnings(record=True):
            warnings.simplefilter("ignore")
            fig = net.plot_response(0.01, label_epoch_dates=True, show=False)
        # Always close the figure, even if one of the assertions fails.
        try:
            legend = fig.axes[0].get_legend()
            texts = legend.get_texts()
            expecteds = ['BW.RJOB..EHZ\n2001-05-15 -- 2006-12-12',
                         'BW.RJOB..EHZ\n2006-12-13 -- 2007-12-17',
                         'BW.RJOB..EHZ\n2007-12-17 -- open']
            self.assertEqual(len(texts), 3)
            for text, expected in zip(texts, expecteds):
                self.assertEqual(text.get_text(), expected)
        finally:
            plt.close(fig)

    def test_len(self):
        """
        Tests the __len__ property.
        """
        net = read_inventory()[0]
        self.assertEqual(len(net), len(net.stations))
        self.assertEqual(len(net), 2)

    def test_network_select(self):
        """
        Test for the select() method of the network class.
        """
        net = read_inventory()[0]

        # Basic asserts to assert some things about the test data.
        self.assertEqual(len(net), 2)
        self.assertEqual(len(net[0]), 12)
        self.assertEqual(len(net[1]), 9)
        self.assertEqual(sum(len(i) for i in net), 21)

        # Artificially move the start time of the first station before the
        # channel start times.
        net[0].start_date = UTCDateTime(1999, 1, 1)

        # Nothing happens if nothing is specified or if everything is selected.
        self.assertEqual(sum(len(i) for i in net.select()), 21)
        self.assertEqual(sum(len(i) for i in net.select(station="*")), 21)
        self.assertEqual(sum(len(i) for i in net.select(location="*")), 21)
        self.assertEqual(sum(len(i) for i in net.select(channel="*")), 21)
        self.assertEqual(sum(len(i) for i in net.select(
            station="*", location="*", channel="*")), 21)

        # No matching station.
        self.assertEqual(sum(len(i) for i in net.select(station="RR")), 0)
        # keep_empty does not do anything in these cases.
        self.assertEqual(sum(len(i) for i in
                             net.select(station="RR", keep_empty=True)), 0)
        # Selecting only one station.
        self.assertEqual(sum(len(i) for i in
                             net.select(station="FUR", keep_empty=True)), 12)
        self.assertEqual(sum(len(i) for i in
                             net.select(station="F*", keep_empty=True)), 12)
        self.assertEqual(sum(len(i) for i in
                             net.select(station="WET", keep_empty=True)), 9)
        self.assertEqual(sum(len(i) for i in
                             net.select(
                                minlatitude=47.89, maxlatitude=48.39,
                                minlongitude=10.88, maxlongitude=11.98)), 12)
        self.assertEqual(sum(len(i) for i in
                             net.select(
                                latitude=48.12, longitude=12.24,
                                maxradius=1)), 12)

        # Test the keep_empty flag.
        net_2 = net.select(time=UTCDateTime(2000, 1, 1))
        self.assertEqual(len(net_2), 0)
        self.assertEqual(sum(len(i) for i in net_2), 0)
        # One is kept - it has no more channels but the station still has a
        # valid start time.
        net_2 = net.select(time=UTCDateTime(2000, 1, 1), keep_empty=True)
        self.assertEqual(len(net_2), 1)
        self.assertEqual(sum(len(i) for i in net_2), 0)

        # location, channel, time, starttime, endtime, and sampling_rate
        # and geographic parameters are also passed on to the station selector.
        select_kwargs = {
            "location": "00",
            "channel": "EHE",
            "time": UTCDateTime(2001, 1, 1),
            "sampling_rate": 123.0,
            "starttime": UTCDateTime(2002, 1, 1),
            "endtime": UTCDateTime(2003, 1, 1),
            "minlatitude": None,
            "maxlatitude": None,
            "minlongitude": None,
            "maxlongitude": None,
            "latitude": None,
            "longitude": None,
            "minradius": None,
            "maxradius": None}

        with mock.patch("obspy.core.inventory.station.Station.select") as p:
            p.return_value = obspy.core.inventory.station.Station("FUR", 1,
                                                                  2, 3)
            net.select(**select_kwargs)

        # The keyword arguments of the (last) Station.select() call must be
        # exactly the ones handed to Network.select().
        self.assertEqual(p.call_args[1], select_kwargs)

    def test_writing_network_before_1990(self):
        """
        Tests that a network starting before 1990 (here: 1880) survives a
        StationXML write/read round trip unchanged.
        """
        inv = obspy.Inventory(networks=[
            Network(code="XX", start_date=obspy.UTCDateTime(1880, 1, 1))],
            source="")
        with io.BytesIO() as buf:
            inv.write(buf, format="stationxml")
            buf.seek(0, 0)
            inv2 = read_inventory(buf)

        self.assertEqual(inv.networks[0], inv2.networks[0])

    def test_network_select_with_empty_stations(self):
        """
        Tests the behaviour of the Network.select() method for empty stations.
        """
        net = read_inventory()[0]

        # Delete all channels.
        for sta in net:
            sta.channels = []

        # 2 stations and 0 channels remain.
        self.assertEqual(len(net), 2)
        self.assertEqual(sum(len(sta) for sta in net), 0)

        # No arguments, everything should be selected.
        self.assertEqual(len(net.select()), 2)

        # Everything selected, nothing should happen.
        self.assertEqual(len(net.select(station="*")), 2)

        # Only select a single station.
        self.assertEqual(len(net.select(station="FUR")), 1)
        self.assertEqual(len(net.select(station="FU?")), 1)
        self.assertEqual(len(net.select(station="W?T")), 1)

        # Once again, this time with the time selection.
        self.assertEqual(len(net.select(time=UTCDateTime(2006, 1, 1))), 0)
        self.assertEqual(len(net.select(time=UTCDateTime(2007, 1, 1))), 1)
        self.assertEqual(len(net.select(time=UTCDateTime(2008, 1, 1))), 2)

    def test_empty_network_code(self):
        """
        Tests that an empty string is an acceptable network code, while
        ``None`` is rejected.
        """
        # An empty string is allowed.
        n = Network(code="")
        self.assertEqual(n.code, "")

        # But None is not allowed.
        with self.assertRaises(ValueError) as e:
            Network(code=None)
        self.assertEqual(e.exception.args[0], "A code is required")

        # Should still serialize to something.
        inv = Inventory(networks=[n])
        with io.BytesIO() as buf:
            inv.write(buf, format="stationxml", validate=True)
            buf.seek(0, 0)
            inv2 = read_inventory(buf)

        self.assertEqual(inv, inv2)
320
321
@unittest.skipIf(not BASEMAP_VERSION, 'basemap not installed')
# Decorator arguments are evaluated at import time even when basemap is
# missing, so guard the version comparison: ordering a missing-version
# sentinel (e.g. None) against a list raises TypeError on Python 3 and would
# break importing this whole module.
@unittest.skipIf(
    bool(BASEMAP_VERSION) and BASEMAP_VERSION >= [1, 1, 0] and
    MATPLOTLIB_VERSION == [3, 0, 1],
    'matplotlib 3.0.1 is not compatible with basemap')
class NetworkBasemapTestCase(unittest.TestCase):
    """
    Tests for the :meth:`~obspy.core.inventory.network.Network.plot` with
    Basemap.
    """
    def setUp(self):
        # Directory with the baseline images used by ImageComparison.
        self.image_dir = os.path.join(os.path.dirname(__file__), 'images')
        # Silence numpy floating point warnings for the duration of each
        # test; the previous settings are restored in tearDown().
        self.nperr = np.geterr()
        np.seterr(all='ignore')

    def tearDown(self):
        np.seterr(**self.nperr)

    @unittest.skipIf(PROJ4_VERSION and PROJ4_VERSION[0] == 5,
                     'unsupported proj4 library')
    def test_location_plot_global(self):
        """
        Tests the network location preview plot, default parameters, using
        Basemap.
        """
        net = read_inventory()[0]
        reltol = 1.3
        # Coordinate lines might be slightly off, depending on the basemap
        # version.
        if BASEMAP_VERSION < [1, 0, 7]:
            reltol = 3.0
        with ImageComparison(self.image_dir, 'network_location-basemap1.png',
                             reltol=reltol) as ic:
            rcParams['savefig.dpi'] = 72
            net.plot(method='basemap', outfile=ic.name)

    def test_location_plot_ortho(self):
        """
        Tests the network location preview plot, ortho projection, some
        non-default parameters, using Basemap.
        """
        net = read_inventory()[0]
        with ImageComparison(self.image_dir,
                             'network_location-basemap2.png') as ic:
            rcParams['savefig.dpi'] = 72
            net.plot(method='basemap', projection='ortho', resolution='c',
                     continent_fill_color='0.5', marker='d',
                     color='yellow', label=False, outfile=ic.name)

    def test_location_plot_local(self):
        """
        Tests the network location preview plot, local projection, some more
        non-default parameters, using Basemap.
        """
        net = read_inventory()[0]
        # Coordinate lines might be slightly off, depending on the basemap
        # version.
        reltol = 2.0
        # Basemap smaller 1.0.4 has a serious issue with plotting. Thus the
        # tolerance must be much higher.
        if BASEMAP_VERSION < [1, 0, 4]:
            reltol = 100.0
        with ImageComparison(self.image_dir, 'network_location-basemap3.png',
                             reltol=reltol) as ic:
            rcParams['savefig.dpi'] = 72
            net.plot(method='basemap', projection='local', resolution='l',
                     size=13**2, outfile=ic.name)
387
388
@unittest.skipIf(not (CARTOPY_VERSION and CARTOPY_VERSION >= [0, 12, 0]),
                 'cartopy not installed')
class NetworkCartopyTestCase(unittest.TestCase):
    """
    Tests for the :meth:`~obspy.core.inventory.network.Network.plot` with
    Cartopy (version 0.12.0 or newer is required).
    """
    def setUp(self):
        # Directory with the baseline images used by ImageComparison.
        self.image_dir = os.path.join(os.path.dirname(__file__), 'images')
        # Silence numpy floating point warnings for the duration of each
        # test; the previous settings are restored in tearDown().
        self.nperr = np.geterr()
        np.seterr(all='ignore')

    def tearDown(self):
        np.seterr(**self.nperr)

    def test_location_plot_global(self):
        """
        Tests the network location preview plot, default parameters, using
        Cartopy.
        """
        net = read_inventory()[0]
        with ImageComparison(self.image_dir,
                             'network_location-cartopy1.png') as ic:
            rcParams['savefig.dpi'] = 72
            net.plot(method='cartopy', outfile=ic.name)

    def test_location_plot_ortho(self):
        """
        Tests the network location preview plot, ortho projection, some
        non-default parameters, using Cartopy.
        """
        net = read_inventory()[0]
        with ImageComparison(self.image_dir,
                             'network_location-cartopy2.png') as ic:
            rcParams['savefig.dpi'] = 72
            net.plot(method='cartopy', projection='ortho', resolution='c',
                     continent_fill_color='0.5', marker='d',
                     color='yellow', label=False, outfile=ic.name)

    def test_location_plot_local(self):
        """
        Tests the network location preview plot, local projection, some more
        non-default parameters, using Cartopy.
        """
        net = read_inventory()[0]
        with ImageComparison(self.image_dir,
                             'network_location-cartopy3.png') as ic:
            rcParams['savefig.dpi'] = 72
            net.plot(method='cartopy', projection='local', resolution='50m',
                     size=13**2, outfile=ic.name)
438
439
def suite():
    """
    Collect every ``test*`` method of the network test cases into one suite.

    :returns: A :class:`unittest.TestSuite` with the tests of
        ``NetworkTestCase``, ``NetworkBasemapTestCase`` and
        ``NetworkCartopyTestCase``, in that order.
    """
    # TestLoader.loadTestsFromTestCase replaces unittest.makeSuite, which is
    # deprecated since Python 3.11 and removed in Python 3.13. A fresh loader
    # uses the same defaults makeSuite applied ('test' method prefix).
    loader = unittest.TestLoader()
    test_suite = unittest.TestSuite()
    for case in (NetworkTestCase, NetworkBasemapTestCase,
                 NetworkCartopyTestCase):
        test_suite.addTest(loader.loadTestsFromTestCase(case))
    return test_suite
446
447
if __name__ == '__main__':
    # When run as a script, execute the aggregated suite built by suite().
    unittest.main(defaultTest='suite')
450