1from __future__ import absolute_import, print_function, division
2import unittest
3
4import numpy as np
5import numpy.linalg
6from numpy.testing import assert_array_almost_equal
7from numpy.testing import dec, assert_array_equal, assert_allclose
8from numpy import inf
9
10import itertools
11
12import theano
13from theano import tensor, function, grad
14from theano.tensor.basic import _allclose
15from theano.tests.test_rop import break_op
16from theano.tests import unittest_tools as utt
17from theano import config
18from theano.tensor.slinalg import (
19    Cholesky, cholesky, CholeskyGrad, Solve, solve,
20    Eigvalsh, EigvalshGrad, eigvalsh, expm, kron)
21from theano.tests.unittest_tools import attr
22
23from nose.plugins.skip import SkipTest
24from nose.tools import assert_raises
25
26try:
27    import scipy.linalg
28    imported_scipy = True
29except ImportError:
30    # some ops (e.g. Cholesky, Solve, A_Xinv_b) won't work
31    imported_scipy = False
32
33
def check_lower_triangular(pd, ch_f):
    # Assert that ch_f(pd) is a lower-triangular factor of pd.
    #
    # pd   -- the matrix that was factorised (indexed as a 2-D array).
    # ch_f -- callable mapping pd to its candidate factor.
    #
    # The top-right entry must be zero and the bottom-left one non-zero;
    # the factor must reproduce pd as ch * ch^T but not as ch^T * ch.
    factor = ch_f(pd)
    last_row = pd.shape[0] - 1
    last_col = pd.shape[1] - 1
    assert factor[0, last_col] == 0
    assert factor[last_row, 0] != 0
    lower_product = np.dot(factor, factor.T)
    upper_product = np.dot(factor.T, factor)
    assert np.allclose(lower_product, pd)
    assert not np.allclose(upper_product, pd)
40
41
def check_upper_triangular(pd, ch_f):
    # Assert that ch_f(pd) is an upper-triangular factor of pd.
    #
    # pd   -- the matrix that was factorised (indexed as a 2-D array).
    # ch_f -- callable mapping pd to its candidate factor.
    #
    # The corner indices are derived from pd.shape (mirroring
    # check_lower_triangular) instead of being hard-coded to 4, so the
    # helper works for any matrix size and not only for 5x5 inputs.
    ch = ch_f(pd)
    last_row = pd.shape[0] - 1
    last_col = pd.shape[1] - 1
    # Bottom-left entry must vanish, top-right entry must not.
    assert ch[last_row, 0] == 0
    assert ch[0, last_col] != 0
    # An upper factor reproduces pd as ch^T * ch, but not as ch * ch^T.
    assert np.allclose(np.dot(ch.T, ch), pd)
    assert not np.allclose(np.dot(ch, ch.T), pd)
48
49
def test_cholesky():
    # Generator test: compile each Cholesky variant on a 5x5 matrix of the
    # form r * r^T and yield the matching triangularity check for it.
    if not imported_scipy:
        raise SkipTest("Scipy needed for the Cholesky op.")

    rng = np.random.RandomState(utt.fetch_seed())
    root = rng.randn(5, 5).astype(config.floatX)
    pd = np.dot(root, root.T)
    x = tensor.matrix()

    # (op applied to x, checker expected to pass), in the order the
    # checks are yielded.
    variants = [
        # The default op.
        (cholesky, check_lower_triangular),
        # Explicit lower-triangular.
        (Cholesky(lower=True), check_lower_triangular),
        # Explicit upper-triangular.
        (Cholesky(lower=False), check_upper_triangular),
        # Upper-triangular with on_error='nan'.
        (Cholesky(lower=False, on_error='nan'), check_upper_triangular),
    ]
    for chol_op, checker in variants:
        ch_f = function([x], chol_op(x))
        yield checker, pd, ch_f
73
74
def test_cholesky_indef():
    # Feed a 2x2 matrix with a negative diagonal entry to Cholesky:
    # on_error='raise' must raise LinAlgError, on_error='nan' must return
    # an all-NaN result.
    if not imported_scipy:
        raise SkipTest("Scipy needed for the Cholesky op.")
    x = tensor.matrix()
    indef = np.array([[1, 0.2], [0.2, -2]]).astype(config.floatX)

    raising_op = Cholesky(lower=True, on_error='raise')
    raising_f = function([x], raising_op(x))
    with assert_raises(scipy.linalg.LinAlgError):
        raising_f(indef)

    nan_op = Cholesky(lower=True, on_error='nan')
    nan_f = function([x], nan_op(x))
    result = nan_f(indef)
    assert np.all(np.isnan(result))
87
88
def test_cholesky_grad():
    # Generator test: yield one zero-argument gradient check per Cholesky
    # variant (default, explicit lower, explicit upper).
    if not imported_scipy:
        raise SkipTest("Scipy needed for the Cholesky op.")
    rng = np.random.RandomState(utt.fetch_seed())
    r = rng.randn(5, 5).astype(config.floatX)

    def make_check(chol_op):
        # The product m * m^T is built inside the graph, so the op is
        # always applied to a matrix of that form while verify_grad
        # varies the underlying input.
        def run_check():
            return utt.verify_grad(lambda m: chol_op(m.dot(m.T)),
                                   [r], 3, rng)
        return run_check

    # Default op, then explicit lower- and upper-triangular ops.
    for chol_op in (cholesky, Cholesky(lower=True), Cholesky(lower=False)):
        yield make_check(chol_op)
106
107
def test_cholesky_grad_indef():
    # Same matrix as test_cholesky_indef, but going through the gradient
    # of sum(cholesky(x)): on_error='raise' must raise LinAlgError and
    # on_error='nan' must give an all-NaN gradient.
    if not imported_scipy:
        raise SkipTest("Scipy needed for the Cholesky op.")
    x = tensor.matrix()
    indef = np.array([[1, 0.2], [0.2, -2]]).astype(config.floatX)

    raising_op = Cholesky(lower=True, on_error='raise')
    raising_grad_f = function([x], grad(raising_op(x).sum(), [x]))
    with assert_raises(scipy.linalg.LinAlgError):
        raising_grad_f(indef)

    nan_op = Cholesky(lower=True, on_error='nan')
    nan_grad_f = function([x], grad(nan_op(x).sum(), [x]))
    grads = nan_grad_f(indef)
    assert np.all(np.isnan(grads))
120
121
@attr('slow')
def test_cholesky_and_cholesky_grad_shape():
    # Generator test: for each Cholesky variant, compile functions that
    # return only the shape of the factor and of its gradient, check that
    # (outside FAST_COMPILE) no Cholesky / CholeskyGrad node is left in
    # the compiled graphs, and yield shape assertions for several sizes.
    if not imported_scipy:
        raise SkipTest("Scipy needed for the Cholesky op.")

    rng = np.random.RandomState(utt.fetch_seed())
    x = tensor.matrix()
    outputs = (cholesky(x), Cholesky(lower=True)(x), Cholesky(lower=False)(x))
    for chol_out in outputs:
        f_chol = theano.function([x], chol_out.shape)
        chol_grad = tensor.grad(chol_out.sum(), x)
        f_cholgrad = theano.function([x], chol_grad.shape)
        topo_chol = f_chol.maker.fgraph.toposort()
        topo_cholgrad = f_cholgrad.maker.fgraph.toposort()
        if config.mode != 'FAST_COMPILE':
            # Only the shape is requested, so neither op should be
            # present in the compiled graph.
            assert not any(node.op.__class__ == Cholesky
                           for node in topo_chol)
            assert not any(node.op.__class__ == CholeskyGrad
                           for node in topo_cholgrad)
        for shp in [2, 3, 5]:
            expected = (shp, shp)
            m = np.cov(rng.randn(shp, shp + 10)).astype(config.floatX)
            yield np.testing.assert_equal, f_chol(m), expected
            yield np.testing.assert_equal, f_cholgrad(m), expected
144
145
def test_eigvalsh():
    # Compare the eigvalsh op with scipy.linalg.eigvalsh, first with an
    # explicit second matrix b and then with b = None.
    if not imported_scipy:
        raise SkipTest("Scipy needed for the geigvalsh op.")
    import scipy.linalg

    A = theano.tensor.dmatrix('a')
    B = theano.tensor.dmatrix('b')
    f = function([A, B], eigvalsh(A, B))

    rng = np.random.RandomState(utt.fetch_seed())
    a = rng.randn(5, 5)
    a = a + a.T  # symmetrise the first matrix
    b = 10 * np.eye(5, 5) + rng.randn(5, 5)
    expected = scipy.linalg.eigvalsh(a, b)
    np.testing.assert_array_almost_equal(f(a, b), expected)

    # None has to be tested separately: it is not a valid ndarray, so
    # DebugMode would complain if it were passed as a function input.
    f_none = function([A], eigvalsh(A, theano.tensor.NoneConst))
    expected_none = scipy.linalg.eigvalsh(a, None)
    np.testing.assert_array_almost_equal(f_none(a), expected_none)
171
172
def test_eigvalsh_grad():
    # Numerically verify the gradient of eigvalsh(a, b) with respect to
    # both inputs, reduced to a scalar by a dot product with fixed weights.
    if not imported_scipy:
        raise SkipTest("Scipy needed for the geigvalsh op.")

    rng = np.random.RandomState(utt.fetch_seed())
    a = rng.randn(5, 5)
    a = a + a.T  # symmetrise the first matrix
    b = 10 * np.eye(5, 5) + rng.randn(5, 5)
    # Hand verify_grad the seeded generator created above rather than the
    # global (unseeded) np.random module, so the whole test is
    # reproducible for a given seed.
    tensor.verify_grad(lambda a, b: eigvalsh(a, b).dot([1, 2, 3, 4, 5]),
                       [a, b], rng=rng)
184
185
class test_Solve(utt.InferShapeTester):
    # Tests for the Solve op: shape inference, numerical correctness,
    # output dtype and gradients.

    def setUp(self):
        super(test_Solve, self).setUp()
        self.op_class = Solve
        self.op = Solve()

    def test_infer_shape(self):
        # Check shape inference for a matrix and for a vector right-hand
        # side.
        if not imported_scipy:
            raise SkipTest("Scipy needed for the Solve op.")
        # (constructor of the symbolic b, shape of its numeric value)
        rhs_cases = [(theano.tensor.matrix, (5, 1)),
                     (theano.tensor.vector, (5,))]
        for make_b, b_shape in rhs_cases:
            # A fresh generator with the same seed is used for each case.
            rng = np.random.RandomState(utt.fetch_seed())
            A = theano.tensor.matrix()
            b = make_b()
            inputs = [A, b]
            outputs = [self.op(A, b)]
            # A must be square
            A_val = np.asarray(rng.rand(5, 5), dtype=config.floatX)
            b_val = np.asarray(rng.rand(*b_shape), dtype=config.floatX)
            self._compile_and_check(inputs, outputs, [A_val, b_val],
                                    self.op_class, warn=False)

    def test_solve_correctness(self):
        # Compare the op with scipy on a general system and on the lower
        # and upper Cholesky factors of the same matrix.
        if not imported_scipy:
            raise SkipTest("Scipy needed for the Cholesky and Solve ops.")
        rng = np.random.RandomState(utt.fetch_seed())
        A = theano.tensor.matrix()
        b = theano.tensor.matrix()
        general_f = theano.function([A, b], self.op(A, b))

        # The symbolic Cholesky factors serve directly as the inputs of
        # the two triangular functions.
        L = Cholesky(lower=True)(A)
        lower_f = theano.function([L, b], self.op(L, b))
        U = Cholesky(lower=False)(A)
        upper_f = theano.function([U, b], self.op(U, b))

        b_val = np.asarray(rng.rand(5, 1), dtype=config.floatX)

        # 1 - general case; A_val is made positive definite as A^T * A
        A_val = np.asarray(rng.rand(5, 5), dtype=config.floatX)
        A_val = np.dot(A_val.transpose(), A_val)
        expected = scipy.linalg.solve(A_val, b_val)
        assert np.allclose(expected, general_f(A_val, b_val))

        # 2 - lower triangular case, then 3 - upper triangular case
        for is_lower, tri_f in ((True, lower_f), (False, upper_f)):
            tri_val = scipy.linalg.cholesky(A_val, lower=is_lower)
            expected = scipy.linalg.solve_triangular(tri_val, b_val,
                                                     lower=is_lower)
            assert np.allclose(expected, tri_f(tri_val, b_val))

    def test_solve_dtype(self):
        # The dtype of the computed result must match the dtype of the
        # symbolic output for every pair of input dtypes.
        if not imported_scipy:
            raise SkipTest("Scipy needed for the Solve op.")

        dtypes = ['uint8', 'uint16', 'uint32', 'uint64',
                  'int8', 'int16', 'int32', 'int64',
                  'float16', 'float32', 'float64']

        A_val = np.eye(2)
        b_val = np.ones((2, 1))

        # try all dtype combinations
        for A_dtype, b_dtype in itertools.product(dtypes, repeat=2):
            A = tensor.matrix(dtype=A_dtype)
            b = tensor.matrix(dtype=b_dtype)
            x = solve(A, b)
            fn = function([A, b], x)
            x_result = fn(A_val.astype(A_dtype), b_val.astype(b_dtype))

            assert x.dtype == x_result.dtype

    def verify_solve_grad(self, m, n, A_structure, lower, rng):
        # Run utt.verify_grad on Solve(A_structure, lower) for a random
        # m x m matrix A and a right-hand side of shape (m,) when n is
        # None, or (m, n) otherwise.
        #
        # Adding the identity keeps the diagonal of A relatively large,
        # which avoids numerical precision issues.
        A_val = (rng.normal(size=(m, m)) * 0.5 +
                 np.eye(m)).astype(config.floatX)
        if A_structure == 'lower_triangular':
            A_val = np.tril(A_val)
        elif A_structure == 'upper_triangular':
            A_val = np.triu(A_val)
        b_size = m if n is None else (m, n)
        b_val = rng.normal(size=b_size).astype(config.floatX)
        # An explicit step size is only supplied in float64.
        eps = 2e-8 if config.floatX == "float64" else None
        solve_op = Solve(A_structure=A_structure, lower=lower)
        utt.verify_grad(solve_op, [A_val, b_val], 3, rng, eps=eps)

    def test_solve_grad(self):
        # Gradient check for every A structure and several shapes of b.
        if not imported_scipy:
            raise SkipTest("Scipy needed for the Solve op.")
        rng = np.random.RandomState(utt.fetch_seed())
        # (m, n) pairs; n = None means a vector right-hand side.
        sizes = [(5, None), (6, 1), (4, 3)]
        for A_structure in ['general', 'lower_triangular',
                            'upper_triangular']:
            lower = (A_structure == 'lower_triangular')
            for m, n in sizes:
                self.verify_solve_grad(m, n, A_structure, lower, rng)
        # lower should have no effect for A_structure == 'general' so also
        # check lower=True case
        self.verify_solve_grad(4, 3, 'general', lower=True, rng=rng)
311
312
def test_expm():
    # The expm op must agree with scipy.linalg.expm on a random 5x5
    # matrix.
    if not imported_scipy:
        raise SkipTest("Scipy needed for the expm op.")
    rng = np.random.RandomState(utt.fetch_seed())
    A = rng.randn(5, 5).astype(config.floatX)

    x = tensor.matrix()
    expm_f = function([x], expm(x))

    expected = scipy.linalg.expm(A)
    actual = expm_f(A)
    np.testing.assert_array_almost_equal(actual, expected)
327
328
def test_expm_grad_1():
    # Gradient check of expm with a symmetric matrix (real eigenvectors).
    if not imported_scipy:
        raise SkipTest("Scipy needed for the expm op.")
    rng = np.random.RandomState(utt.fetch_seed())
    # The matrix is left in float64 (no cast to floatX) for better
    # numerical stability.
    base = rng.randn(5, 5)
    sym = base + base.T

    tensor.verify_grad(expm, [sym], rng=rng)
339
340
def test_expm_grad_2():
    # Gradient check of expm with a non-symmetric matrix that has a real
    # eigenspectrum.
    if not imported_scipy:
        raise SkipTest("Scipy needed for the expm op.")
    rng = np.random.RandomState(utt.fetch_seed())
    # The matrix is left in float64 (no cast to floatX) for better
    # numerical stability.
    base = rng.randn(5, 5)
    w = rng.randn(5)**2
    # Sandwich the symmetric matrix base + base^T between diag(w^0.5)
    # and diag(w^-0.5).
    left = np.diag(w**0.5)
    right = np.diag(w**(-0.5))
    A = left.dot(base + base.T).dot(right)
    assert not np.allclose(A, A.T)

    tensor.verify_grad(expm, [A], rng=rng)
353
354
def test_expm_grad_3():
    # Gradient check of expm with a non-symmetric matrix (complex
    # eigenvectors).
    if not imported_scipy:
        raise SkipTest("Scipy needed for the expm op.")
    rng = np.random.RandomState(utt.fetch_seed())
    # The matrix is left in float64 (no cast to floatX) for better
    # numerical stability.
    general = rng.randn(5, 5)

    tensor.verify_grad(expm, [general], rng=rng)
364
365
class TestKron(utt.InferShapeTester):
    # Tests comparing theano's kron with the scipy and numpy versions.

    # Generator shared by all tests of the class (fixed seed).
    rng = np.random.RandomState(43)

    def setUp(self):
        super(TestKron, self).setUp()
        self.op = kron

    def test_perform(self):
        # Compare kron with scipy.linalg.kron for inputs of 1 to 4
        # dimensions; the (vector, vector) combination is skipped.
        if not imported_scipy:
            raise SkipTest('kron tests need the scipy package to be installed')

        left_shapes = [(2,), (2, 3), (2, 3, 4), (2, 3, 4, 5)]
        right_shapes = [(6,), (6, 7), (6, 7, 8), (6, 7, 8, 9)]
        for shp0 in left_shapes:
            x = tensor.tensor(dtype='floatX',
                              broadcastable=(False,) * len(shp0))
            a = np.asarray(self.rng.rand(*shp0)).astype(config.floatX)
            for shp1 in right_shapes:
                total_ndim = len(shp0) + len(shp1)
                if total_ndim == 2:
                    continue
                y = tensor.tensor(dtype='floatX',
                                  broadcastable=(False,) * len(shp1))
                f = function([x, y], kron(x, y))
                b = self.rng.rand(*shp1).astype(config.floatX)
                out = f(a, b)
                if total_ndim == 3:
                    # Newer versions of scipy want 4 dimensions at least,
                    # so a gets an extra leading dimension and the result
                    # is flattened.
                    scipy_val = scipy.linalg.kron(
                        a[np.newaxis, :], b).flatten()
                else:
                    scipy_val = scipy.linalg.kron(a, b)
                utt.assert_allclose(out, scipy_val)

    def test_numpy_2d(self):
        # Compare kron with np.kron for a single pair of 2-D inputs.
        shp0 = (2, 3)
        shp1 = (6, 7)
        x = tensor.tensor(dtype='floatX',
                          broadcastable=(False,) * len(shp0))
        a = np.asarray(self.rng.rand(*shp0)).astype(config.floatX)
        y = tensor.tensor(dtype='floatX',
                          broadcastable=(False,) * len(shp1))
        f = function([x, y], kron(x, y))
        b = self.rng.rand(*shp1).astype(config.floatX)
        out = f(a, b)
        assert np.allclose(out, np.kron(a, b))
413