1import unittest
2from .. import util
3import numpy as np
4
5
class ShuffleMatrix_Tester(unittest.TestCase):
    """Tests for ``util.shuffle_matrix``."""

    def setUp(self):
        # 4x4 matrix holding 0..15 in row-major order.
        self.X = np.arange(16).reshape(4, 4)

    def test_shuffle_matrix(self):
        # The expected values below are tied to this seed; shuffle_matrix
        # presumably draws from NumPy's global RNG — confirm against util.
        np.random.seed(10)
        obs = util.shuffle_matrix(self.X, list(range(4))).flatten().tolist()
        # Equivalent to reordering both rows and columns of X by [2, 0, 3, 1].
        exp = [10, 8, 11, 9, 2, 0, 3, 1, 14, 12, 15, 13, 6, 4, 7, 5]
        # Compare whole lists so a length mismatch is reported as a failure
        # rather than passing silently or raising IndexError.
        self.assertEqual(exp, obs)
17
18
class GetLower_Tester(unittest.TestCase):
    """Tests for ``util.get_lower``."""

    def setUp(self):
        # 4x4 matrix holding 0..15 in row-major order.
        self.X = np.arange(16).reshape(4, 4)

    def test_get_lower(self):
        obs = util.get_lower(self.X).flatten().tolist()
        # Strictly-lower-triangle entries of X, read row by row:
        # (1,0), (2,0), (2,1), (3,0), (3,1), (3,2).
        exp = [4, 8, 9, 12, 13, 14]
        # Compare whole lists so extra or missing elements fail the test.
        self.assertEqual(exp, obs)
30
31
class FillDiagonal_Tester(unittest.TestCase):
    """Tests for ``util.fill_empty_diagonals`` on 2-D and 3-D inputs."""

    def setUp(self):
        # 3x3 matrix whose last row is all zeros.
        self.p3 = np.array([[0.5, 0.5, 0], [0.3, 0.7, 0], [0, 0, 0]])
        # Stack of two 3x3 matrices; the second has all-zero first and
        # last rows.
        self.p23 = np.array(
            [
                [[0.5, 0.5, 0], [0.3, 0.7, 0], [0, 0, 0]],
                [[0, 0, 0], [0.3, 0.7, 0], [0, 0, 0]],
            ]
        )

    def test_fill_diag(self):
        # 2-D input: the all-zero row receives a 1 on its diagonal entry;
        # non-empty rows are unchanged.
        obs = util.fill_empty_diagonals(self.p3)
        exp = np.array([[0.5, 0.5, 0.0], [0.3, 0.7, 0.0], [0.0, 0.0, 1.0]])
        np.testing.assert_array_almost_equal(exp, obs)

        # 3-D input: each matrix in the stack has its own empty rows filled.
        obs = util.fill_empty_diagonals(self.p23)
        exp = np.array(
            [
                [[0.5, 0.5, 0.0], [0.3, 0.7, 0.0], [0.0, 0.0, 1.0]],
                [[1.0, 0.0, 0.0], [0.3, 0.7, 0.0], [0.0, 0.0, 1.0]],
            ]
        )
        np.testing.assert_array_almost_equal(exp, obs)
72
73
# Explicit suite used when this module is run as a script. Every TestCase
# class in this module must be listed here or it will be skipped.
suite = unittest.TestSuite()
test_classes = [ShuffleMatrix_Tester, GetLower_Tester, FillDiagonal_Tester]
_loader = unittest.TestLoader()
for _test_class in test_classes:
    suite.addTest(_loader.loadTestsFromTestCase(_test_class))

if __name__ == "__main__":
    runner = unittest.TextTestRunner()
    result = runner.run(suite)
    # Report failures through the process exit status (0 = all passed).
    raise SystemExit(0 if result.wasSuccessful() else 1)
83