1"""Unit test for raster.py"""
2from .. import raster
3import unittest
4import numpy as np
5import pandas as pd
6from xarray import DataArray
7
8
class Testraster(unittest.TestCase):
    """Tests for the DataArray <-> spatial-weights conversions in ``raster``."""

    def setUp(self):
        """Build the fixture rasters shared by every test.

        - ``da1``: default ``raster.testDataArray()``; the tests below expect
          only 5 of its cells to become weights observations.
        - ``da2``: a ``(1, 4, 4)`` test raster built with
          ``missing_vals=False``.
        - ``da3``: ``da2`` with its dimensions renamed to non-default labels
          (band -> layer, x -> longitude, y -> latitude).
        - ``data1``: five ones, one per observation of ``da1``.
        - ``da4``: a 1x1 raster whose data is replaced by a string; ``da2W``
          is expected to reject it with ``ValueError``.
        """
        self.da1 = raster.testDataArray()
        self.da2 = raster.testDataArray((1, 4, 4), missing_vals=False)
        self.da3 = self.da2.rename(
            {"band": "layer", "x": "longitude", "y": "latitude"})
        self.data1 = pd.Series(np.ones(5))
        self.da4 = raster.testDataArray((1, 1), missing_vals=False)
        self.da4.data = np.array([["test"]])

    def test_da2W(self):
        """``da2W`` builds a ``W`` whose index mirrors ``to_series()``."""
        # Queen weights with k=2 on the raster that has missing cells.
        w1 = raster.da2W(self.da1, "queen", k=2, n_jobs=-1)
        self.assertEqual(w1[(1, -30.0, -180.0)],
                         {(1, -90.0, 60.0): 1, (1, -90.0, -60.0): 1})
        self.assertEqual(w1[(1, -30.0, 180.0)],
                         {(1, -90.0, -60.0): 1, (1, -90.0, 60.0): 1})
        self.assertEqual(w1.n, 5)
        self.assertEqual(w1.index.names, self.da1.to_series().index.names)
        # Build the index list once and check its leading entries together.
        self.assertEqual(w1.index.tolist()[:4], [
            (1, 90.0, 180.0),
            (1, -30.0, -180.0),
            (1, -30.0, 180.0),
            (1, -90.0, -60.0),
        ])

        # Rook contiguity on the fully populated raster.
        w2 = raster.da2W(self.da2, "rook")
        self.assertEqual(
            sorted(w2.neighbors[(1, -90.0, 180.0)]),
            [(1, -90.0, 60.0), (1, -30.0, 180.0)])
        self.assertEqual(
            sorted(w2.neighbors[(1, -90.0, 60.0)]),
            [(1, -90.0, -60.0), (1, -90.0, 180.0), (1, -30.0, 60.0)])
        self.assertEqual(w2.n, 16)
        self.assertEqual(w2.index.names, self.da2.to_series().index.names)
        self.assertEqual(w2.index.tolist(),
                         self.da2.to_series().index.tolist())

        # Non-default dimension names must be mapped via ``coords_labels``.
        coords_labels = {
            "z_label": "layer",
            "y_label": "latitude",
            "x_label": "longitude",
        }
        w3 = raster.da2W(self.da3, z_value=1, coords_labels=coords_labels)
        # With the default criterion the corner cell also picks up its
        # diagonal neighbour (1, -30.0, 60.0).
        self.assertEqual(
            sorted(w3.neighbors[(1, -90.0, 180.0)]),
            [(1, -90.0, 60.0), (1, -30.0, 60.0), (1, -30.0, 180.0)])
        self.assertEqual(w3.n, 16)
        self.assertEqual(w3.index.names, self.da3.to_series().index.names)
        self.assertEqual(w3.index.tolist(),
                         self.da3.to_series().index.tolist())

    def test_da2WSP(self):
        """``da2WSP`` builds a sparse ``WSP`` indexed like ``to_series()``."""
        w1 = raster.da2WSP(self.da1, "rook", n_jobs=-1)
        rows, cols = w1.sparse.shape
        pct_nonzero = w1.sparse.nnz / float(rows * cols)
        # Floating-point ratio: compare approximately, not bit-for-bit.
        self.assertAlmostEqual(pct_nonzero, 0.08)
        data = w1.sparse.todense().tolist()
        self.assertEqual(data[3], [0, 0, 0, 0, 1])
        self.assertEqual(data[4], [0, 0, 0, 1, 0])
        self.assertEqual(w1.index.names, self.da1.to_series().index.names)
        self.assertEqual(w1.index.tolist()[:4], [
            (1, 90.0, 180.0),
            (1, -30.0, -180.0),
            (1, -30.0, 180.0),
            (1, -90.0, -60.0),
        ])

        # da2 was built with missing_vals=False, so including no-data cells
        # must not change the resulting weights.
        w2 = raster.da2WSP(self.da2, "queen", k=2, include_nodata=True)
        w3 = raster.da2WSP(self.da2, "queen", k=2, n_jobs=-1)
        self.assertEqual(w2.sparse.nnz, w3.sparse.nnz)
        self.assertEqual(w2.sparse.todense().tolist(),
                         w3.sparse.todense().tolist())
        self.assertEqual(w2.n, 16)
        self.assertEqual(w2.index.names, self.da2.to_series().index.names)
        self.assertEqual(w2.index.tolist(),
                         self.da2.to_series().index.tolist())

    def test_w2da(self):
        """Round-trip: ``da2W`` followed by ``w2da`` reproduces ``da2``."""
        w2 = raster.da2W(self.da2, "rook", n_jobs=-1)
        da2 = raster.w2da(self.da2.data.flatten(), w2,
                          self.da2.attrs, self.da2.coords)
        self.assertTrue(da2.equals(self.da2))

    def test_wsp2da(self):
        """``wsp2da`` restores the full grid coordinates of the source raster."""
        wsp1 = raster.da2WSP(self.da1, "queen")
        da1 = raster.wsp2da(self.data1, wsp1)
        self.assertEqual(da1["y"].values.tolist(),
                         self.da1["y"].values.tolist())
        self.assertEqual(da1["x"].values.tolist(),
                         self.da1["x"].values.tolist())
        self.assertEqual(da1.shape, (1, 4, 4))

    def test_da_checker(self):
        """``da2W`` rejects a raster with non-numeric (string) data."""
        with self.assertRaises(ValueError):
            raster.da2W(self.da4)
95
96
# Module-level suite so other runners can aggregate these tests.
suite = unittest.TestLoader().loadTestsFromTestCase(Testraster)

if __name__ == "__main__":
    import sys

    runner = unittest.TextTestRunner()
    result = runner.run(suite)
    # TextTestRunner.run only reports; propagate failures to the exit status
    # so scripted/CI invocations notice a failing run.
    sys.exit(0 if result.wasSuccessful() else 1)
102