1# Copyright (c) 2010-2020, Manfred Moitzi
2# License: MIT License
3import pytest
4import pickle
5from math import radians, sin, cos, pi, isclose
# Matrix44 is imported from 'ezdxf.math._matrix44' to test the Python
# implementation.
7from ezdxf.math import close_vectors
8from ezdxf.math._matrix44 import Matrix44
9from ezdxf.acc import USE_C_EXT
10
# All Matrix44 implementations the tests of this module are run against.
m44_classes = [Matrix44]

# The implementation from 'ezdxf.acc' is only imported and added to the list
# when USE_C_EXT is true.
if USE_C_EXT:
    from ezdxf.acc.matrix44 import Matrix44 as CMatrix44

    m44_classes += [CMatrix44]
17
18
@pytest.fixture(params=m44_classes)
def m44(request):
    """Provide one class of ``m44_classes`` to the requesting test.

    The fixture is parametrized by ``m44_classes``, so a test using it is
    executed once for each class in that list.
    """
    matrix_class = request.param
    return matrix_class
22
23
def diag(values, m44_cls):
    """Return a new ``m44_cls()`` object with overwritten main diagonal.

    Args:
        values: iterable of values, the value at position ``i`` is stored
            at index ``[i, i]``
        m44_cls: class which is called without arguments to create the
            returned object

    Entries which are not overwritten keep the value set by ``m44_cls()``.
    """
    result = m44_cls()
    for position, entry in enumerate(values):
        result[position, position] = entry
    return result
29
30
def equal_matrix(m1, m2, abs_tol=1e-9):
    """Return ``True`` if all 4x4 entries of `m1` and `m2` are close.

    The entries are read by ``[row, col]`` indices for rows and columns 0..3
    and compared by ``math.isclose()`` using the absolute tolerance
    `abs_tol`; the default relative tolerance of ``isclose()`` applies too.
    The comparison stops at the first pair of entries which is not close.
    """
    return all(
        isclose(m1[row, col], m2[row, col], abs_tol=abs_tol)
        for row in range(4)
        for col in range(4)
    )
37
38
39class TestMatrix44:
40    @pytest.mark.parametrize('index', [0, 1, 2, 3])
41    def test_default_constructor(self, index, m44):
42        matrix = m44()
43        assert matrix[index, index] == 1.
44
45    def test_numbers_constructor(self, m44):
46        matrix = m44([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])
47        assert matrix.get_row(0) == (0.0, 1.0, 2.0, 3.0)
48        assert matrix.get_row(1) == (4.0, 5.0, 6.0, 7.0)
49        assert matrix.get_row(2) == (8.0, 9.0, 10.0, 11.0)
50        assert matrix.get_row(3) == (12.0, 13.0, 14.0, 15.0)
51
52    def test_row_constructor(self, m44):
53        matrix = m44(
54            (0, 1, 2, 3),
55            (4, 5, 6, 7),
56            (8, 9, 10, 11),
57            (12, 13, 14, 15)
58        )
59        assert matrix.get_row(0) == (0.0, 1.0, 2.0, 3.0)
60        assert matrix.get_row(1) == (4.0, 5.0, 6.0, 7.0)
61        assert matrix.get_row(2) == (8.0, 9.0, 10.0, 11.0)
62        assert matrix.get_row(3) == (12.0, 13.0, 14.0, 15.0)
63
64    def test_invalid_row_constructor(self, m44):
65        with pytest.raises(ValueError):
66            m44(
67                (0, 1, 2, 3),
68                (4, 5, 6, 7),
69                (8, 9, 10, 11),
70                (12, 13, 14, 15, 16)
71            )
72        with pytest.raises(ValueError):
73            m44(
74                (0, 1, 2, 3),
75                (4, 5, 6, 7),
76                (8, 9, 10, 11),
77                (12, 13, 14,),
78            )
79
80    def test_invalid_number_constructor(self, m44):
81        pytest.raises(ValueError, m44, range(17))
82        pytest.raises(ValueError, m44, range(15))
83
84    def test_get_item_does_not_support_slicing(self, m44):
85        with pytest.raises(TypeError):
86            _ = m44()[:]
87
88    def test_get_item_index_error(self, m44):
89        with pytest.raises(IndexError):
90            _ = m44()[(-1, -1)]
91        with pytest.raises(IndexError):
92            _ = m44()[(0, 4)]
93        with pytest.raises(IndexError):
94            _ = m44()[(1, -1)]
95        with pytest.raises(IndexError):
96            _ = m44()[4, 4]
97
98    def test_set_item_does_not_support_slicing(self, m44):
99        with pytest.raises(TypeError):
100            m44()[:] = (1, 2)
101
102    def test_set_item_index_error(self, m44):
103        with pytest.raises(IndexError):
104            m44()[-1, -1] = 0
105        with pytest.raises(IndexError):
106            m44()[4, 4] = 0
107
108    def test_iter(self, m44):
109        values = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
110        matrix = m44(values)
111        for v1, m1 in zip(values, matrix):
112            assert v1 == m1
113
114    def test_copy(self, m44):
115        values = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
116        m1 = m44(values)
117        matrix = m1.copy()
118        for v1, m1 in zip(values, matrix):
119            assert v1 == m1
120
121    def test_get_row_index_error(self, m44):
122        with pytest.raises(IndexError):
123            m44().get_row(-1)
124        with pytest.raises(IndexError):
125            m44().get_row(4)
126
127    def test_set_row(self, m44):
128        matrix = m44()
129        matrix.set_row(0, (2., 3., 4., 5.))
130        assert matrix.get_row(0) == (2.0, 3.0, 4.0, 5.0)
131        matrix.set_row(1, (6., 7., 8., 9.))
132        assert matrix.get_row(1) == (6.0, 7.0, 8.0, 9.0)
133        matrix.set_row(2, (10., 11., 12., 13.))
134        assert matrix.get_row(2) == (10.0, 11.0, 12.0, 13.0)
135        matrix.set_row(3, (14., 15., 16., 17.))
136        assert matrix.get_row(3) == (14.0, 15.0, 16.0, 17.0)
137
138    def test_set_row_index_error(self, m44):
139        with pytest.raises(IndexError):
140            m44().set_row(-1, (0,))
141        with pytest.raises(IndexError):
142            m44().set_row(4, (0,))
143
144    def test_get_col(self, m44):
145        matrix = m44()
146        assert matrix.get_col(0) == (1.0, 0.0, 0.0, 0.0)
147        assert matrix.get_col(1) == (0.0, 1.0, 0.0, 0.0)
148        assert matrix.get_col(2) == (0.0, 0.0, 1.0, 0.0)
149        assert matrix.get_col(3) == (0.0, 0.0, 0.0, 1.0)
150
151    def test_get_col_index_error(self, m44):
152        with pytest.raises(IndexError):
153            m44().get_col(-1)
154        with pytest.raises(IndexError):
155            m44().get_col(4)
156
157    def test_set_col(self, m44):
158        matrix = m44()
159        matrix.set_col(0, (2., 3., 4., 5.))
160        assert matrix.get_col(0) == (2.0, 3.0, 4.0, 5.0)
161        matrix.set_col(1, (6., 7., 8., 9.))
162        assert matrix.get_col(1) == (6.0, 7.0, 8.0, 9.0)
163        matrix.set_col(2, (10., 11., 12., 13.))
164        assert matrix.get_col(2) == (10.0, 11.0, 12.0, 13.0)
165        matrix.set_col(3, (14., 15., 16., 17.))
166        assert matrix.get_col(3) == (14.0, 15.0, 16.0, 17.0)
167
168    def test_set_col_index_error(self, m44):
169        with pytest.raises(IndexError):
170            m44().set_col(-1, (0,))
171        with pytest.raises(IndexError):
172            m44().set_col(4, (0,))
173
174    def test_translate(self, m44):
175        t = m44.translate(10, 20, 30)
176        x = diag((1., 1., 1., 1.), m44)
177        x[3, 0] = 10.
178        x[3, 1] = 20.
179        x[3, 2] = 30.
180        assert equal_matrix(t, x) is True
181
182    def test_scale(self, m44):
183        t = m44.scale(10, 20, 30)
184        x = diag((10., 20., 30., 1.), m44)
185        assert equal_matrix(t, x) is True
186
187    def test_x_rotate(self, m44):
188        alpha = radians(25)
189        t = m44.x_rotate(alpha)
190        x = diag((1., 1., 1., 1.), m44)
191        x[1, 1] = cos(alpha)
192        x[2, 1] = -sin(alpha)
193        x[1, 2] = sin(alpha)
194        x[2, 2] = cos(alpha)
195        assert equal_matrix(t, x) is True
196
197    def test_y_rotate(self, m44):
198        alpha = radians(25)
199        t = m44.y_rotate(alpha)
200        x = diag((1., 1., 1., 1.), m44)
201        x[0, 0] = cos(alpha)
202        x[2, 0] = sin(alpha)
203        x[0, 2] = -sin(alpha)
204        x[2, 2] = cos(alpha)
205        assert equal_matrix(t, x) is True
206
207    def test_z_rotate(self, m44):
208        alpha = radians(25)
209        t = m44.z_rotate(alpha)
210        x = diag((1., 1., 1., 1.), m44)
211        x[0, 0] = cos(alpha)
212        x[1, 0] = -sin(alpha)
213        x[0, 1] = sin(alpha)
214        x[1, 1] = cos(alpha)
215        assert equal_matrix(t, x) is True
216
217    def test_chain(self, m44):
218        s = m44.scale(10, 20, 30)
219        t = m44.translate(10, 20, 30)
220
221        c = m44.chain(s, t)
222        x = diag((10., 20., 30., 1.), m44)
223        x[3, 0] = 10.
224        x[3, 1] = 20.
225        x[3, 2] = 30.
226        assert equal_matrix(c, x) is True
227
228    def test_chain2(self, m44):
229        s = m44.scale(10, 20, 30)
230        t = m44.translate(10, 20, 30)
231        r = m44.axis_rotate(angle=pi / 2, axis=(0., 0., 1.))
232        points = ((23., 97., .5), (2., 7., 13.))
233
234        p1 = s.transform_vertices(points)
235        p1 = t.transform_vertices(p1)
236        p1 = r.transform_vertices(p1)
237
238        c = m44.chain(s, t, r)
239        p2 = c.transform_vertices(points)
240        assert close_vectors(p1, p2) is True
241
242    def test_transform(self, m44):
243        t = m44.scale(2., .5, 1.)
244        r = t.transform((10., 20., 30.))
245        assert r == (20., 10., 30.)
246
247    def test_multiply(self, m44):
248        m1 = m44(range(16))
249        m2 = m44(range(16))
250        res = m1 * m2
251        expected = m44(
252            (56.0, 62.0, 68.0, 74.0),
253            (152.0, 174.0, 196.0, 218.0),
254            (248.0, 286.0, 324.0, 362.0),
255            (344.0, 398.0, 452.0, 506.0)
256        )
257        assert equal_matrix(res, expected)
258        # __matmul__()
259        res = m1 @ m2
260        assert equal_matrix(res, expected)
261
262    def test_transpose(self, m44):
263        matrix = m44((0, 1, 2, 3),
264                     (4, 5, 6, 7),
265                     (8, 9, 10, 11),
266                     (12, 13, 14, 15))
267        matrix.transpose()
268        assert matrix.get_row(0) == (0.0, 4.0, 8.0, 12.0)
269        assert matrix.get_row(1) == (1.0, 5.0, 9.0, 13.0)
270        assert matrix.get_row(2) == (2.0, 6.0, 10.0, 14.0)
271        assert matrix.get_row(3) == (3.0, 7.0, 11.0, 15.0)
272
273    def test_inverse_error(self, m44):
274        m = m44([1] * 16)
275        pytest.raises(ZeroDivisionError, m.inverse)
276
277    def test_axis_rotate_for_axis_normalization(self, m44):
278        m1 = m44.axis_rotate((0, 0, 1), 1.23)
279        m2 = m44.axis_rotate((0, 0, 0.5), 1.23)
280        for a, b in zip(m1, m2):
281            assert isclose(a, b)
282
283    def test_assign_after_initialised(self, m44):
284        matrix = m44()
285        matrix[0, 0] = 12
286        matrix2 = m44()
287        assert matrix2[0, 0] == 1
288
289        values = list(range(16))
290        matrix = m44(values)
291        matrix[0, 0] = 12
292        assert values[0] == 0
293        assert matrix[0, 0] == 12
294
295    def test_picklable(self, m44):
296        matrix = m44((0.1, 1, 2, 3),
297                     (4, 5, 6, 7),
298                     (8, 9, 10, 11),
299                     (12, 13, 14, 15))
300        pickled_matrix = pickle.loads(pickle.dumps(matrix))
301        assert equal_matrix(matrix, pickled_matrix)
302        assert type(matrix) is type(pickled_matrix)
303        matrix[0, 0] = 12
304        assert not equal_matrix(matrix, pickled_matrix)
305
306    def test_shear_xy(self, m44):
307        angle = pi / 4
308        matrix = m44.shear_xy(angle_x=angle, angle_y=-angle)
309        assert matrix[0, 1] == pytest.approx(-1.0)
310        assert matrix[1, 0] == pytest.approx(1.0)
311