1from __future__ import absolute_import, print_function, division
2from copy import copy
3from itertools import product as itertools_product
4from unittest import TestCase
5
6import numpy as np
7from numpy import (arange, array, common_type, complex64, complex128, float32,
8                   float64, newaxis, shape, transpose, zeros)
9from numpy.testing import assert_array_almost_equal
10from itertools import product
11from six.moves import xrange
12
13import theano
14import theano.tensor as T
15from theano import tensor, In, shared, config
16from theano.compat import exc_message
17from theano.printing import pp
18from theano.tensor.blas import (_dot22, _dot22scalar, res_is_a, _as_scalar,
19                                _is_real_matrix, _gemm_canonicalize,
20                                _factor_canonicalized, Gemm, Gemv,
21                                gemm_inplace, gemm_no_inplace,
22                                InconsistencyError, Ger, ger, ger_destructive)
23from theano.tests import unittest_tools
24from .test_basic import (as_tensor_variable, inplace_func,
25                         compile, inplace)
26import theano.tensor.blas_scipy
27from theano.tests.unittest_tools import attr
28
29
30if config.mode == 'FAST_COMPILE':
31    mode_not_fast_compile = 'FAST_RUN'
32else:
33    mode_not_fast_compile = config.mode
34
35mode_blas_opt = theano.compile.get_default_mode().including(
36    'BlasOpt', 'specialize', 'InplaceBlasOpt')
37mode_blas_opt = mode_blas_opt.excluding('c_blas')
38
39
40def test_dot_eq():
41    assert T.Dot() == T.Dot()
42
43
44def sharedX(x, name):
45    return theano.shared(np.asarray(x, config.floatX), name=name)
46
47
48class t_gemm(TestCase):
49    """
50    This test suite is supposed to establish that gemm works as it is supposed to.
51    """
52
53    def setUp(self):
54        unittest_tools.seed_rng()
55        Gemm.debug = False
56
57    @staticmethod
58    def _gemm(z, a, x, y, b):
59        assert a.shape == ()
60        assert b.shape == ()
61        return b * z + a * np.dot(x, y)
62
63    @staticmethod
64    def rand(*args):
65        return np.random.rand(*args)
66
67    def cmp(self, z_, a_, x_, y_, b_):
68        for dtype in ['float32', 'float64', 'complex64', 'complex128']:
69            z = np.asarray(z_, dtype=dtype)
70            a = np.asarray(a_, dtype=dtype)
71            x = np.asarray(x_, dtype=dtype)
72            y = np.asarray(y_, dtype=dtype)
73            b = np.asarray(b_, dtype=dtype)
74
75            def cmp_linker(z, a, x, y, b, l):
76                z, a, x, y, b = [np.asarray(p) for p in (z, a, x, y, b)]
77                z_orig = z.copy()
78                tz, ta, tx, ty, tb = [as_tensor_variable(p).type()
79                                      for p in (z, a, x, y, b)]
80
81                f = inplace_func([tz, ta, tx, ty, tb],
82                                 gemm_inplace(tz, ta, tx, ty, tb),
83                                 mode=compile.Mode(optimizer=None, linker=l))
84                f(z, a, x, y, b)
85                z_after = self._gemm(z_orig, a, x, y, b)
86
87                # print z_orig, z_after, z, type(z_orig), type(z_after), type(z)
88                unittest_tools.assert_allclose(z_after, z)
89                if a == 0.0 and b == 1.0:
90                    return
91                elif z_orig.size == 0:
92                    self.assertTrue(z.size == 0)
93                else:
94                    self.assertFalse(np.all(z_orig == z))
95
96            cmp_linker(copy(z), a, x, y, b, 'c|py')
97            cmp_linker(copy(z), a, x, y, b, 'py')
98
99            if (not dtype.startswith("complex") and theano.config.cxx):
100                # If theano.config.blas.ldflags is empty, Theano will use
101                # a NumPy C implementation of [sd]gemm_.
102                cmp_linker(copy(z), a, x, y, b, 'c')
103
104    def test0a(self):
105        Gemm.debug = True
106        try:
107            gemm_no_inplace([1.], 1., [1.], [1.], 1.)
108        except TypeError as e:
109            if exc_message(e) is Gemm.E_rank:
110                return
111        self.fail()
112
113    def test0(self):
114        try:
115            self.cmp(1., 0., 1.0, 1.0, 1.0)
116        except TypeError as e:
117            if exc_message(e) is Gemm.E_rank:
118                return
119        self.fail()
120
121    def test2(self):
122        try:
123            self.cmp(2., 1.0, [3, 2, 1.], [[1], [2], [3.]], 1.0)
124        except TypeError as e:
125            self.assertTrue(exc_message(e) == Gemm.E_rank)
126            return
127        self.fail()
128
129    def test4(self):
130        self.cmp(self.rand(3, 4), 1.0, self.rand(3, 5), self.rand(5, 4), 0.0)
131
132    def test5(self):
133        self.cmp(self.rand(3, 4), 1.0,
134                 self.rand(3, 5), self.rand(5, 4), 1.0)
135
136    def test6(self):
137        self.cmp(self.rand(3, 4), 1.0,
138                 self.rand(3, 5), self.rand(5, 4), -1.0)
139
140    def test7(self):
141        self.cmp(self.rand(3, 4), 0.0,
142                 self.rand(3, 5), self.rand(5, 4), 0.0)
143
144    def test8(self):
145        self.cmp(self.rand(3, 4), 0.0,
146                 self.rand(3, 5), self.rand(5, 4), 0.6)
147
148    def test9(self):
149        self.cmp(self.rand(3, 4), 0.0,
150                 self.rand(3, 5), self.rand(5, 4), -1.0)
151
152    def test10(self):
153        self.cmp(self.rand(3, 4), -1.0, self.rand(3, 5), self.rand(5, 4), 0.0)
154
155    def test11(self):
156        self.cmp(self.rand(3, 4), -1.0,
157                 self.rand(3, 5), self.rand(5, 4), 1.0)
158
159    def test12(self):
160        self.cmp(self.rand(3, 4), -1.0,
161                 self.rand(3, 5), self.rand(5, 4), -1.0)
162
163    def test_shape_0(self):
164        self.cmp(self.rand(0, 4), -1.0, self.rand(0, 5), self.rand(5, 4), -1.0)
165        self.cmp(self.rand(3, 0), -1.0, self.rand(3, 5), self.rand(5, 0), -1.0)
166        self.cmp(self.rand(3, 4), -1.0, self.rand(3, 0), self.rand(0, 4), -1.0)
167        self.cmp(self.rand(0, 0), -1.0, self.rand(0, 5), self.rand(5, 0), -1.0)
168        self.cmp(self.rand(0, 0), -1.0, self.rand(0, 0), self.rand(0, 0), -1.0)
169
170    def test_factorised_scalar(self):
171        a = T.matrix()
172        b = T.matrix()
173        s = theano.shared(np.zeros((5, 5)).astype(config.floatX))
174
175        lr1 = T.constant(0.01).astype(config.floatX)
176        lr2 = T.constant(2).astype(config.floatX)
177        l2_reg = T.constant(0.0001).astype(config.floatX)
178
179        # test constant merge with gemm
180        f = theano.function([a, b], updates=[(s, lr1 * T.dot(a, b) +
181                            l2_reg * lr2 * s)],
182                            mode=mode_not_fast_compile).maker.fgraph.toposort()
183        # [Gemm{inplace}(<TensorType(float64, matrix)>, 0.01,
184        # <TensorType(float64, matrix)>, <TensorType(float64, matrix)>,
185        # 2e-06)]
186        assert len(f) == 1
187        assert f[0].op == gemm_inplace
188
189        # test factored scalar with merge
190        f = theano.function([a, b], updates=[(s, lr1 * (T.dot(a, b) -
191                                                        l2_reg * s))],
192                            mode=mode_not_fast_compile).maker.fgraph.toposort()
193        # [Gemm{inplace}(<TensorType(float64, matrix)>, 0.01,
194        # <TensorType(float64, matrix)>, <TensorType(float64, matrix)>,
195        # -2e-06)]
196        assert len(f) == 1
197        assert f[0].op == gemm_inplace
198
199        # test factored scalar with merge and neg
200        f = theano.function([a, b],
201                            updates=[(s, s - lr1 * (s * .0002 + T.dot(a, b)))],
202                            mode=mode_not_fast_compile).maker.fgraph.toposort()
203        # [Gemm{inplace}(<TensorType(float64, matrix)>, -0.01,
204        # <TensorType(float64, matrix)>, <TensorType(float64, matrix)>,
205        # 0.999998)]
206        assert len(f) == 1
207        assert f[0].op == gemm_inplace
208
209    def test_destroy_map0(self):
210        # test that only first input can be overwritten.
211        Z = as_tensor_variable(self.rand(2, 2))
212        try:
213            gemm_inplace(Z, 1.0, Z, Z, 1.0)
214        except InconsistencyError as e:
215            if exc_message(e) == Gemm.E_z_uniq:
216                return
217        self.fail()
218
219    def test_destroy_map1(self):
220        # test that only first input can be overwritten.
221        Z = as_tensor_variable(self.rand(2, 2))
222        A = as_tensor_variable(self.rand(2, 2))
223        try:
224            gemm_inplace(Z, 1.0, A, inplace.transpose_inplace(Z), 1.0)
225        except InconsistencyError as e:
226            if exc_message(e) == Gemm.E_z_uniq:
227                return
228        self.fail()
229
230    def test_destroy_map2(self):
231        # test that only first input can be overwritten.
232        Z = as_tensor_variable(self.rand(2, 2))
233        A = as_tensor_variable(self.rand(2, 2))
234        try:
235            gemm_inplace(Z, 1.0, inplace.transpose_inplace(Z), A, 1.0)
236        except InconsistencyError as e:
237            if exc_message(e) == Gemm.E_z_uniq:
238                return
239        self.fail()
240
241    def test_destroy_map3(self):
242        # test that only first input can be overwritten
243        Z = as_tensor_variable(self.rand(2, 2))
244        A = as_tensor_variable(self.rand(2, 2))
245        try:
246            gemm_inplace(Z, 1.0, Z, A, 1.0)
247        except InconsistencyError as e:
248            if exc_message(e) == Gemm.E_z_uniq:
249                return
250        self.fail()
251
252    def test_destroy_map4(self):
253        # test that dot args can be aliased
254        Z = shared(self.rand(2, 2), name='Z')
255        A = shared(self.rand(2, 2), name='A')
256        one = T.constant(1.0).astype(Z.dtype)
257        f = inplace_func([], gemm_inplace(Z, one, A, A, one))
258        f()
259        f = inplace_func([], gemm_inplace(Z, one, A, A.T, one))
260        f()
261
262    def test_transposes(self):
263        # three square matrices which are not contiguous
264        A = self.rand(4, 5)[:, :4]
265        B = self.rand(4, 5)[:, :4]
266        C = self.rand(4, 5)[:, :4]
267
268        def t(z, x, y, a=1.0, b=0.0, l='c|py', dt='float64'):
269            z, a, x, y, b = [theano._asarray(p, dtype=dt)
270                             for p in (z, a, x, y, b)]
271            # z_orig = z.copy()
272            z_after = self._gemm(z, a, x, y, b)
273
274            tz, ta, tx, ty, tb = [shared(p) for p in (z, a, x, y, b)]
275
276            # f = inplace_func([tz,ta,tx,ty,tb], gemm_inplace(tz,ta,tx,ty,tb),
277            #                 mode = compile.Mode(optimizer = None, linker=l))
278            # f(z, a, x, y, b)
279            f = inplace_func([], gemm_inplace(tz, ta, tx, ty, tb),
280                             mode=compile.Mode(optimizer=None, linker=l))
281            f()
282            unittest_tools.assert_allclose(z_after, tz.get_value(borrow=True))
283            f()
284            unittest_tools.assert_allclose(z_after, tz.get_value(borrow=True))
285            f()
286            unittest_tools.assert_allclose(z_after, tz.get_value(borrow=True))
287
288            # tz.value *= 0 # clear z's value
289            y_T = ty.get_value(borrow=True).T
290            ty.set_value(tx.get_value(borrow=True).T, borrow=True)
291            tx.set_value(y_T, borrow=True)
292
293            f()
294            # test that the transposed version of multiplication gives
295            # same answer
296            unittest_tools.assert_allclose(z_after, tz.get_value(borrow=True).T)
297
298        t(C, A, B)
299        t(C.T, A, B)
300        t(C, A.T, B, dt='float32')
301        t(C, A, B.T)
302        t(C.T, A.T, B)
303        t(C, A.T, B.T, dt='float32')
304        t(C.T, A, B.T)
305        t(C.T, A.T, B.T, dt='float32')
306
307        t(C, A[:, :2], B[:2, :])
308        t(C.T, A[:, :2], B[:2, :], dt='float32')
309        t(C, A[:2, :].T, B[:2, :])
310        t(C.T, A[:2, :].T, B[:2, :], dt='float32')
311        t(C, A[:2, :].T, B[:, :2].T)
312        t(C.T, A[:2, :].T, B[:, :2].T)
313
314        try:
315            t(C.T, A[:2, :], B[:, :2].T)
316        except ValueError as e:
317            if exc_message(e).find('aligned') >= 0:
318                return
319        self.fail()
320
321    def test_non_contiguous(self):
322        # Like test_transposes but with matrices without any
323        # continuous dimension
324        A = self.rand(4, 4, 3)
325        B = self.rand(4, 4, 3)
326        C = self.rand(4, 4, 3)
327
328        def t(z, x, y, a=1.0, b=0.0, l='c|py', dt='float64'):
329            z, a, x, y, b = [theano._asarray(p, dtype=dt)
330                             for p in (z, a, x, y, b)]
331            z_orig = z.copy()
332            z_after = np.zeros_like(z_orig)
333            for i in xrange(3):
334                z_after[:, :, i] = self._gemm(z[:, :, i], a,
335                                              x[:, :, i], y[:, :, i], b)
336
337            tz, ta, tx, ty, tb = [shared(p) for p in (z, a, x, y, b)]
338            for i in xrange(3):
339                f_i = inplace_func([],
340                                   gemm_inplace(tz[:, :, i],
341                                   ta, tx[:, :, i], ty[:, :, i], tb),
342                                   mode=compile.Mode(optimizer=None, linker=l))
343                for j in xrange(3):
344                    # tz will not _always_ be overwritten,
345                    # and adding update={...} in the call to function()
346                    # will create cycles, so we update by hand.
347                    z_i = f_i()
348                    z = tz.get_value(borrow=True, return_internal_type=True)
349                    z[:, :, i] = z_i
350
351                    unittest_tools.assert_allclose(z_after[:, :, i],
352                                                   tz.get_value(borrow=True)[:, :, i])
353
354                tz_i = gemm_no_inplace(tz[:, :, i], ta, tx[
355                    :, :, i], ty[:, :, i], tb)
356                g_i = theano.function(
357                    [], tz_i, updates=[(tz, T.set_subtensor(tz[:, :, i],
358                                                            tz_i))],
359                    mode=compile.Mode(optimizer=None, linker=l))
360                for j in xrange(3):
361                    g_i()
362                    unittest_tools.assert_allclose(z_after[:, :, i],
363                                                   tz.get_value(borrow=True)[:, :, i])
364
365        t(C, A, B)
366        t(C.transpose((1, 0, 2)), A, B)
367        t(C, A.transpose((1, 0, 2)), B, dt='float32')
368        t(C, A, B.transpose((1, 0, 2)))
369        t(C.transpose((1, 0, 2)), A.transpose((1, 0, 2)), B)
370        t(C, A.transpose((1, 0, 2)), B.transpose((1, 0, 2)), dt='float32')
371        t(C.transpose((1, 0, 2)), A, B.transpose((1, 0, 2)))
372        t(C.transpose((1, 0, 2)), A.transpose((1, 0, 2)), B.transpose((
373            1, 0, 2)), dt='float32')
374
375
376class TestGemmNoFlags(object):
377    gemm = gemm_no_inplace
378    M = 4
379    N = 5
380    K = 6
381    slice_step = 3
382
383    def setUp(self):
384        unittest_tools.seed_rng()
385
386    def get_variable(self, V, to_transpose, to_slice):
387        if to_transpose:
388            V = V.T
389        if to_slice:
390            V = V[::self.slice_step]
391        return V
392
393    def get_function(self, dtype,
394                     transpose_A=False, transpose_B=False, transpose_C=False,
395                     slice_A=False, slice_B=False, slice_C=False):
396        alpha = theano.tensor.scalar(dtype=dtype, name='alpha')
397        beta = theano.tensor.scalar(dtype=dtype, name='beta')
398        A = theano.tensor.matrix(dtype=dtype, name='A')
399        B = theano.tensor.matrix(dtype=dtype, name='B')
400        C = theano.tensor.matrix(dtype=dtype, name='C')
401
402        A1 = self.get_variable(A, transpose_A, slice_A)
403        B1 = self.get_variable(B, transpose_B, slice_B)
404        C1 = self.get_variable(C, transpose_C, slice_C)
405
406        return theano.function([alpha, A, B, beta, C], self.gemm(C1, alpha, A1, B1, beta))
407
408    def generate_value(self, dtype, width, height, to_transpose, to_slice):
409        if to_slice:
410            if to_transpose:
411                shape = (height, width * self.slice_step)
412            else:
413                shape = (width * self.slice_step, height)
414        else:
415            if to_transpose:
416                shape = (height, width)
417            else:
418                shape = (width, height)
419        return np.random.random(shape).astype(dtype)
420
421    def get_data(self, dtype, alpha, beta,
422                 transpose_A=False, transpose_B=False, transpose_C=False,
423                 slice_A=False, slice_B=False, slice_C=False):
424        A = self.generate_value(dtype, self.M, self.N, transpose_A, slice_A)
425        B = self.generate_value(dtype, self.N, self.K, transpose_B, slice_B)
426        C = self.generate_value(dtype, self.M, self.K, transpose_C, slice_C)
427        return (alpha, A, B, beta, C)
428
429    def get_value(self, V, to_transpose, to_slice):
430        if to_transpose:
431            V = V.T
432        if to_slice:
433            V = V[::self.slice_step]
434        return V
435
436    def compute_ref(self, alpha, A, B, beta, C,
437                    transpose_A, transpose_B, transpose_C,
438                    slice_A, slice_B, slice_C):
439        A = self.get_value(A, transpose_A, slice_A)
440        B = self.get_value(B, transpose_B, slice_B)
441        C = self.get_value(C, transpose_C, slice_C)
442        return alpha * np.dot(A, B) + beta * C
443
444    @theano.change_flags({'blas.ldflags': ''})
445    def run_gemm(self, dtype, ALPHA, BETA,
446                 transpose_A, transpose_B, transpose_C,
447                 slice_A, slice_B, slice_C):
448        f = self.get_function(dtype, transpose_A, transpose_B, transpose_C, slice_A, slice_B, slice_C)
449        values = self.get_data(dtype, ALPHA, BETA, transpose_A, transpose_B, transpose_C, slice_A, slice_B, slice_C)
450        assert any(isinstance(node.op, Gemm) for node in f.maker.fgraph.apply_nodes)
451        z_val = f(*values)
452        assert z_val.dtype == dtype
453        assert tuple(z_val.shape) == (self.M, self.K)
454        ref_val = self.compute_ref(*(values + (transpose_A, transpose_B, transpose_C, slice_A, slice_B, slice_C)))
455        unittest_tools.assert_allclose(ref_val, z_val)
456
457    def test_gemm(self):
458        dtypes = ('float32', 'float64')
459        scalars = (0, 1, -2)
460        booleans = (False, True)
461        # dtype, alpha, beta, transA, transB, transC, sliceA, sliceB, sliceC
462        iterables = [dtypes] + ([scalars] * 2) + ([booleans] * 6)
463        for dtype, alpha, beta, tA, tB, tC, sA, sB, sC in product(*iterables):
464            yield (self.run_gemm, dtype, alpha, beta, tA, tB, tC, sA, sB, sC)
465
466
467def test_res_is_a():
468    X, Y, Z, a, b = XYZab()
469
470    assert not res_is_a(a, T.sqrt)
471    assert not res_is_a(a + a, T.sqrt)
472    assert res_is_a(T.sqrt(a + a), T.sqrt)
473
474    # leave the maxclients  stuff untested because it requires being in an fgraph.
475
476
477class t_as_scalar(TestCase):
478    def test0(self):
479        # Test that it works on scalar constants
480        a = T.constant(2.5)
481        b = T.constant(np.asarray([[[0.5]]]))
482        b2 = b.dimshuffle()
483        assert b2.ndim == 0
484        d_a = T.DimShuffle([], [])(a)
485        d_b = T.DimShuffle([True, True, True], [0, 2, 1])(b)
486        d_a2 = T.DimShuffle([], ['x', 'x', 'x'])(a)
487
488        self.assertTrue(_as_scalar(a) == a)
489        self.assertTrue(_as_scalar(b) != b)
490        self.assertTrue(_as_scalar(d_a) != d_a)
491        self.assertTrue(_as_scalar(d_b) != d_b)
492        self.assertTrue(_as_scalar(d_a2) != d_a2)
493
494    def test1(self):
495        # Test that it fails on nonscalar constants
496        a = T.constant(np.ones(5))
497        self.assertTrue(_as_scalar(a) is None)
498        self.assertTrue(_as_scalar(T.DimShuffle([False], [0, 'x'])(a)) is None)
499
500    def test2(self):
501        # Test that it works on scalar variables
502        a = T.dscalar()
503        d_a = T.DimShuffle([], [])(a)
504        d_a2 = T.DimShuffle([], ['x', 'x'])(a)
505
506        self.assertTrue(_as_scalar(a) is a)
507        self.assertTrue(_as_scalar(d_a) is a)
508        self.assertTrue(_as_scalar(d_a2) is a)
509
510    def test3(self):
511        # Test that it fails on nonscalar variables
512        a = T.matrix()
513        self.assertTrue(_as_scalar(a) is None)
514        self.assertTrue(_as_scalar(T.DimShuffle([False, False],
515                                                [0, 'x', 1])(a)) is None)
516
517
518class T_real_matrix(TestCase):
519    def test0(self):
520        self.assertTrue(_is_real_matrix(T.DimShuffle([False, False],
521                                                     [1, 0])(T.matrix())))
522        self.assertTrue(not _is_real_matrix(T.DimShuffle([False],
523                                                         ['x', 0])
524                                            (T.dvector())))
525
526
527def fail(msg):
528    print('FAIL', msg)
529    assert False
530
531
532"""
533This test suite ensures that Gemm is inserted where it belongs, and
534that the resulting functions compute the same things as the originals.
535"""
536
537
538def XYZab():
539    return T.matrix(), T.matrix(), T.matrix(), T.scalar(), T.scalar()
540
541
542class Failure(Exception):
543    pass
544
545
546def just_gemm(i, o, ishapes=[(4, 3), (3, 5), (4, 5), (), ()],
547              max_graphlen=0, expected_nb_gemm=1):
548    try:
549        f = inplace_func(
550            [In(ii, mutable=True, allow_downcast=True) for ii in i],
551            o,
552            mode='FAST_RUN',
553            on_unused_input='ignore')
554        nb_gemm = 0
555        for node in f.maker.fgraph.apply_nodes:
556            if isinstance(node.op, T.Dot):
557                raise Failure('dot not changed to gemm_inplace in graph')
558            if node.op == _dot22:
559                raise Failure('_dot22 not changed to gemm_inplace in graph')
560            if node.op == gemm_inplace:
561                nb_gemm += 1
562        assert nb_gemm == expected_nb_gemm, (nb_gemm, expected_nb_gemm)
563        g = inplace_func(i, o, mode=compile.Mode(linker='py', optimizer=None),
564                         allow_input_downcast=True, on_unused_input='ignore')
565        for node in g.maker.fgraph.apply_nodes:
566            if node.op == gemm_inplace:
567                raise Exception('gemm_inplace in original graph')
568
569        graphlen = len(f.maker.fgraph.toposort())
570        if max_graphlen and (graphlen <= max_graphlen):
571            # theano.printing.debugprint(f)
572            assert False, 'graphlen=%i>%i' % (graphlen, max_graphlen)
573
574        rng = np.random.RandomState(unittest_tools.fetch_seed(234))
575        r0 = f(*[np.asarray(rng.randn(*sh), config.floatX)
576                 for sh in ishapes])
577        rng = np.random.RandomState(unittest_tools.fetch_seed(234))
578        r1 = g(*[np.asarray(rng.randn(*sh), config.floatX)
579                 for sh in ishapes])
580        max_abs_err = np.max(np.abs(r0[0] - r1[0]))
581        eps = 1.0e-8
582        if config.floatX == 'float32':
583            eps = 1.0e-6
584        if max_abs_err > eps:
585            raise Failure('GEMM is computing the wrong output. max_rel_err =',
586                          max_abs_err)
587    except Failure:
588        for node in f.maker.fgraph.toposort():
589            print('GRAPH', node)
590        raise
591
592
593@unittest_tools.assertFailure_fast
594def test_gemm_opt0():
595    # Many subgraphs whose dots can be eliminated
596    X, Y, Z, a, b = XYZab()
597
598    just_gemm([X, Y, Z, a, b], [T.dot(X, Y) * a + Z * b])
599    just_gemm([X, Y, Z, a, b], [a * T.dot(X, Y) + b * Z])
600    just_gemm([X, Y, Z, a, b], [b * Z + a * T.dot(X, Y)])
601    just_gemm([X, Y, Z, a, b], [T.dot(X, Y) * a - Z * b])
602    just_gemm([X, Y, Z, a, b], [a * T.dot(X, Y) - b * Z])
603    just_gemm([X, Y, Z, a, b], [b * Z - a * T.dot(X, Y)])
604
605    # with transposes (transposes should be pushed through dot in canonicalize)
606    just_gemm([X, Y, Z, a, b], [b * Z.T - a * T.dot(Y.T, X.T)])
607    just_gemm([X, Y, Z, a, b], [b * Z.T + a * b * T.dot(X, Y).T])
608    just_gemm([X, Y, Z, a, b], [b * Z + a * T.dot(X, Y).T],
609              ishapes=[(5, 3), (3, 4), (4, 5), (), ()])
610
611    # with N multiplications instead of just one
612    just_gemm([X, Y, Z, a, b], [(b * b) * Z * a + (a * a) * T.dot(X, Y) * b])
613    just_gemm([X, Y, Z, a, b], [Z + T.dot(X, Y)])
614    just_gemm([X, Y, Z, a, b], [Z * b + T.dot(X, Y)])
615    just_gemm([X, Y, Z, a, b], [Z + a * b * a * T.dot(X, Y)])
616    just_gemm([X, Y, Z, a, b], [(b * b) * Z * a - (a * a) * T.dot(X, Y) * b])
617    just_gemm([X, Y, Z, a, b], [Z - T.dot(X, Y)])
618    just_gemm([X, Y, Z, a, b], [Z * b - T.dot(X, Y)])
619    just_gemm([X, Y, Z, a, b], [Z - a * b * a * T.dot(X, Y)])
620
621
622@unittest_tools.assertFailure_fast
623def test_gemm_opt_double_gemm():
624    # This is the pattern that shows up in the autoencoder
625    X, Y, Z, a, b = T.matrix(), T.matrix(), T.matrix(), T.scalar(), T.scalar()
626    R, S, c = T.matrix(), T.matrix(), T.scalar()
627
628    just_gemm([X, Y, Z, a, b, R, S, c],
629              [Z * c + a * T.dot(X, Y) + b * T.dot(R, S).T],
630              ishapes=[(4, 3), (3, 5), (4, 5), (), (), (5, 9), (9, 4), ()],
631              expected_nb_gemm=2)
632
633    ishapes = [(4, 3), (3, 5), (4, 5), (), (), (5, 9), (9, 4), ()]
634    i = [X, Y, Z, a, b, R, S, c]
635    o = [(a * T.dot(X, Y) +
636         gemm_inplace(Z, b, S.T, R.T, T.constant(1.0).astype(config.floatX)))]
637    try:
638        f = inplace_func([In(ii, mutable=True) for ii in i], o,
639                         mode='FAST_RUN', on_unused_input='ignore')
640        for node in f.maker.fgraph.apply_nodes:
641            if isinstance(node.op, T.Dot):
642                raise Failure('dot in graph')
643            if node.op == _dot22:
644                raise Failure('_dot22 in graph')
645        g = inplace_func(i, o, mode=compile.Mode(linker='py', optimizer=None),
646                         on_unused_input='ignore')
647        # for node in g.maker.fgraph.apply_nodes:
648        #    if node.op == gemm_inplace: raise Failure('gemm_inplace in graph')
649
650        rng = np.random.RandomState(unittest_tools.fetch_seed(234))
651        r0 = f(*[np.asarray(rng.randn(*sh), config.floatX)
652                 for sh in ishapes])
653        rng = np.random.RandomState(unittest_tools.fetch_seed(234))
654        r1 = g(*[np.asarray(rng.randn(*sh), config.floatX)
655                 for sh in ishapes])
656        max_abs_err = np.max(np.abs(r0[0] - r1[0]))
657        eps = 1.0e-8
658        if config.floatX == 'float32':
659            eps = 1.0e-6
660        if max_abs_err > eps:
661            raise Failure(
662                'GEMM is computing the wrong output. max_rel_err =',
663                max_abs_err)
664    except Failure:
665        for node in f.maker.fgraph.toposort():
666            print('GRAPH', node)
667        raise
668
669
670def test_gemm_canonicalize():
671    X, Y, Z, a, b = T.matrix('X'), T.matrix('Y'), T.matrix('Z'), T.scalar(
672        'a'), T.scalar('b')
673    c, d = T.scalar('c'), T.scalar('d')
674    u = T.row('u')
675    v = T.vector('v')
676    w = T.col('w')
677
678    can = []
679    _gemm_canonicalize(X + Y + Z, 1.0, can, 0)
680    assert can == [(1.0, X), (1.0, Y), (1.0, Z)]
681
682    can = []
683    _gemm_canonicalize(X + Y + u, 1.0, can, 0)
684    assert can == [(1.0, X), (1.0, Y), (1.0, u)], can
685
686    can = []
687    _gemm_canonicalize(X + Y + v, 1.0, can, 0)
688    # [(1.0, X), (1.0, Y), (1.0, InplaceDimShuffle{x,0}(v))]
689    assert can[:2] == [(1.0, X), (1.0, Y)]
690    assert isinstance(can[2], tuple)
691    assert len(can[2]) == 2
692    assert can[2][0] == 1.0
693    assert can[2][1].owner
694    assert isinstance(can[2][1].owner.op, T.DimShuffle)
695    assert can[2][1].owner.inputs == [v]
696
697    can = []
698    _gemm_canonicalize(X + Y + w, 1.0, can, 0)
699    assert can == [(1.0, X), (1.0, Y), (1.0, w)], can
700
701    can = []
702    _gemm_canonicalize(a * X + Y - b * Z * c, 1.0, can, 0)
703    assert can[0] == (a, X)
704    assert can[1] == (1.0, Y)
705    assert can[2][0].owner.op == T.mul
706    assert can[2][0].owner.inputs[0].owner.op == T.neg
707    assert can[2][0].owner.inputs[0].owner.inputs[0] == c
708    assert can[2][0].owner.inputs[1] == b
709
710    can = []
711    _gemm_canonicalize((-d) * X - (a * X + Y - b * Z * c), 1.0, can, 0)
712    # print can
713    assert can[0][0].owner.op == T.neg
714    assert can[0][0].owner.inputs[0] == d
715    assert can[0][1] == X
716    assert can[1][0].owner.op == T.neg
717    assert can[1][0].owner.inputs[0] == a
718    assert can[2] == (-1.0, Y)
719    assert can[3][0].owner.op == T.mul
720    assert can[3][0].owner.inputs == [c, b]
721
722
723def test_gemm_factor():
724    X, Y = T.matrix('X'), T.matrix('Y')
725
726    assert [(1.0, X), (1.0, Y)] == _factor_canonicalized([(1.0, X), (1.0, Y)])
727    assert [(2.0, X)] == _factor_canonicalized([(1.0, X), (1.0, X)])
728
729
730def test_upcasting_scalar_nogemm():
731    # Test that the optimization does not crash when the scale has an incorrect
732    # dtype, and forces upcasting of the result
733    v = T.fmatrix('v')
734    w = T.fmatrix('w')
735    t = T.fmatrix('t')
736    alpha = T.dscalar('a')
737
738    rval = T.dot(w, v) * alpha + t
739
740    f = theano.function([w, v, t, alpha], rval)
741    t = f.maker.fgraph.toposort()
742    assert np.sum([isinstance(n.op, Gemm) for n in t]) == 0
743    # theano.printing.debugprint(f, print_type=True)
744
745    v = T.fmatrix('v')
746    w = T.fmatrix('w')
747    t = T.fmatrix('t')
748    alpha = T.cscalar('a')
749
750    on_opt_error = config.on_opt_error
751    try:
752        config.on_opt_error = 'raise'
753        rval = T.dot(w, v) * alpha + t
754        f = theano.function([w, v, t, alpha], rval)
755    finally:
756        config.on_opt_error = on_opt_error
757
758    t = f.maker.fgraph.toposort()
759    assert np.sum([isinstance(n.op, Gemm) for n in t]) == 0
760    # theano.printing.debugprint(f, print_type=True)
761
762
763def test_gemm_nested():
764    X, Y, Z, a, b = T.matrix('X'), T.matrix('Y'), T.matrix('Z'), T.scalar(
765        'a'), T.scalar('b')
766    R, S, U, c, d = T.matrix('R'), T.matrix('S'), T.matrix('U'), T.scalar(
767        'c'), T.scalar('d')
768
769    just_gemm([X, Y, Z, R, S, U, a, b, c, d],
770              [a * Z - b * (c * T.dot(X, Y) + d * Z)],
771              ishapes=[(2, 3), (3, 4), (2, 4), (2, 3), (3, 4),
772                       (2, 4), (), (), (), ()],
773              max_graphlen=1)
774    # print "---------------------"
775    just_gemm([X, Y, Z, R, S, U, a, b, c, d],
776              [a * Z - b * (c * T.dot(X, Y) + d * Z + c * Z)],
777              ishapes=[(2, 3), (3, 4), (2, 4), (2, 3), (3, 4),
778                       (2, 4), (), (), (), ()],
779              max_graphlen=1)
780    # print "---------------------"
781    just_gemm([X, Y, Z, R, S, U, a, b, c, d],
782              [a * Z - b * (c * T.dot(X, Y) + d * Z + c * U)],
783              ishapes=[(2, 3), (3, 4), (2, 4), (2, 3), (3, 4),
784                       (2, 4), (), (), (), ()],
785              max_graphlen=3)
786
787
788def test_gemm_opt_wishlist():
789    X, Y, Z, a, b = T.matrix(), T.matrix(), T.matrix(), T.scalar(), T.scalar()
790
791    # with >2 additions of the same T.dot(X,Y term
792    just_gemm([X, Y, Z, a, b],
793              [(b * b) * Z * a + (a * a) * T.dot(X, Y) + b * T.dot(X, Y)])
794
795    just_gemm([X, Y, Z, a, b], [Z + T.dot(X, Y) + T.dot(X, Y)])
796
797
798def test_gemm_with_vector():
799    # Many subgraphs whose dots can be eliminated.  This adds a
800    # vector two the previous test, which triggers the long-sought GEMM
801    # bug.
802
803    X, Y, Z, a, b = XYZab()
804    v = T.vector()
805
806    def my_just_gemm(o):
807        i = [X, Y, Z, a, b, v]
808        ishapes = [(4, 3), (3, 5), (4, 5), (), (), (5, )]
809        just_gemm(i, o, ishapes=ishapes)
810
811    my_just_gemm([v + T.dot(X, Y) * a + Z * b])
812    my_just_gemm([v + a * T.dot(X, Y) + b * Z])
813    my_just_gemm([v + b * Z + a * T.dot(X, Y)])
814    my_just_gemm([v + T.dot(X, Y) * a - Z * b])
815    my_just_gemm([v + a * T.dot(X, Y) - b * Z])
816    my_just_gemm([v + b * Z - a * T.dot(X, Y)])
817
818    # with N multiplications instead of just one
819    my_just_gemm([v + (b * b) * Z * a + (a * a) * T.dot(X, Y) * b])
820    my_just_gemm([v + Z + T.dot(X, Y)])
821    my_just_gemm([v + Z * b + T.dot(X, Y)])
822    my_just_gemm([v + Z + a * b * a * T.dot(X, Y)])
823    my_just_gemm([v + (b * b) * Z * a - (a * a) * T.dot(X, Y) * b])
824    my_just_gemm([Z - T.dot(X, Y) + v])
825    my_just_gemm([Z * b - T.dot(X, Y) + v])
826    my_just_gemm([Z - a * b * a * T.dot(X, Y) + v])
827
828
829def test_gemm_opt_vector_stuff():
830    X, Y, a = T.matrix(), T.matrix(), T.scalar()
831    u, v = T.vector(), T.vector()
832
833    f = inplace_func([a, u, v], a + T.dot(u, v), mode='FAST_RUN')
834    if gemm_inplace in [n.op for n in f.maker.fgraph.apply_nodes]:
835        raise Failure('gemm_inplace in graph')
836
837    f = inplace_func([a, u, X, Y], a * u + T.dot(X, Y), mode='FAST_RUN')
838    if (gemm_inplace in [n.op for n in f.maker.fgraph.apply_nodes]):
839        raise Failure('gemm_inplace in graph')
840
841
842def test_gemm_unrolled():
843    # This test that the gemm optimizer remove the dot22 that was
844    # present in the graph. Otherwise, this add a gemm, but still
845    # compute the dot22.
846
847    # This was not always the case in the with this the following code.
848
849    batch_size = 100
850    rep_size = 40
851    rng = np.random.RandomState([1, 2, 3])
852
853    for num_rounds in range(1, 10):
854        W = sharedX(rng.randn(rep_size, rep_size), name='W')
855        V = sharedX(np.zeros((batch_size, rep_size)), name='V')
856        H = sharedX(np.zeros((batch_size, rep_size)), name='H')
857        G = sharedX(np.zeros((batch_size, rep_size)), name='G')
858
859        cur_V = V
860        cur_H = H
861
862        def update_V(cur_H):
863            return T.nnet.sigmoid(T.dot(cur_H, W.T))
864
865        def update_H(cur_V):
866            return T.nnet.sigmoid(T.dot(cur_V, W) + T.dot(G, W.T))
867
868        for i in xrange(num_rounds):
869            cur_V = update_V(cur_H)
870            cur_H = update_H(cur_V)
871
872        unrolled_theano = theano.function([], updates=[(V, cur_V), (H, cur_H)],
873                                          name='unrolled_theano')
874        nb_dot = sum([1 for node in unrolled_theano.maker.fgraph.toposort()
875                      if isinstance(node.op, (theano.tensor.Dot,
876                                              theano.tensor.blas.Dot22,
877                                              theano.tensor.blas.Gemm))])
878        # Each num_rounds add 3 dot, but one of them is always the same.
879        # So the final graph should have 1 + 2* num_rounds dot variant op.
880        assert nb_dot == num_rounds * 2 + 1, nb_dot
881
882        unrolled_theano()
883
884
885def test_inplace0():
886    # should fail to insert gemm_inplace because gemm_inplace would
887    # create cycles
888    X, Y, Z, a, b = T.matrix('X'), T.matrix('Y'), T.matrix('Z'), T.scalar(
889        'a'), T.scalar('b')
890    R, S, c = T.matrix('R'), T.matrix('S'), T.scalar('c')
891
892    f = inplace_func([Z, b, R, S],
893                     [Z * (Z + b * T.dot(R, S).T)], mode='FAST_RUN')
894    if (gemm_inplace in [n.op for n in f.maker.fgraph.apply_nodes]):
895        print(pp(f.maker.fgraph.outputs[0]))
896        raise Failure('gemm_inplace in graph')
897    assert gemm_no_inplace in [n.op for n in f.maker.fgraph.apply_nodes]
898
899    # gemm_inplace should be inserted here, to work in-place on Z*c
900    f = inplace_func([X, Y, Z, a, b, R, S, c],
901                     [Z * (c * Z + a * T.dot(X, Y) + b * T.dot(R, S).T)],
902                     mode='FAST_RUN')
903    if (gemm_inplace not in [n.op for n in f.maker.fgraph.apply_nodes]):
904        theano.printing.debugprint(f)
905        raise Failure('no gemm_inplace in graph')
906
907
908def test_inplace1():
909    X, Y, Z, a, b = XYZab()
910    # with > 2 terms in the overall addition
911    f = inplace_func([X, Y, Z],
912                     [Z + Z + T.dot(X, Y)], mode='FAST_RUN')
913    # theano.printing.debugprint(f)
914    # it doesn't work inplace because we didn't mark Z as mutable input
915    assert [n.op for n in f.maker.fgraph.apply_nodes] == [gemm_no_inplace]
916
917
918def test_dot22():
919    for dtype1 in ['float32', 'float64', 'complex64', 'complex128']:
920        a = T.matrix(dtype=dtype1)
921        for dtype2 in ['float32', 'float64', 'complex64', 'complex128']:
922            b = T.matrix(dtype=dtype2)
923            f = theano.function([a, b], T.dot(a, b), mode=mode_blas_opt)
924            topo = f.maker.fgraph.toposort()
925            if dtype1 == dtype2:
926                assert _dot22 in [x.op for x in topo], (dtype1, dtype2)
927            else:
928                check = [isinstance(x.op, T.Dot) for x in topo]
929                assert any(check), (dtype1, dtype2)
930            rng = np.random.RandomState(unittest_tools.fetch_seed())
931
932            def cmp(a_shp, b_shp):
933                av = rng.uniform(size=a_shp).astype(dtype1)
934                bv = rng.uniform(size=b_shp).astype(dtype2)
935                f(av, bv)
936
937            cmp((3, 4), (4, 5))
938            cmp((0, 4), (4, 5))
939            cmp((3, 0), (0, 5))
940            cmp((3, 4), (4, 0))
941            cmp((0, 4), (4, 0))
942            cmp((0, 0), (0, 0))
943
944
945@attr('slow')
946def test_dot22scalar():
947    # including does not seem to work for 'local_dot_to_dot22' and
948    # 'local_dot22_to_dot22scalar'
949    # TODO: exclude other optimizations in BlasOpt?
950    # m = theano.compile.get_default_mode().including('local_dot_to_dot22',
951    #                           'local_dot22_to_dot22scalar','specialize')
952    # m = theano.compile.get_default_mode().including('BlasOpt', 'specialize')
953    rng = np.random.RandomState(unittest_tools.fetch_seed())
954    for dtype1 in ['complex64', 'complex128']:
955        a = T.matrix('a', dtype=dtype1)
956        for dtype2 in ['complex64', 'complex128']:
957            b = T.matrix('b', dtype=dtype2)
958            for dtype3 in ['complex64', 'complex128']:
959                c = T.matrix('c', dtype=dtype3)
960                for dtype4 in ['complex64', 'complex128']:
961                    cst = theano.tensor.basic.constant(.2, dtype=dtype4)
962                    cst2 = theano.tensor.basic.constant(.1, dtype=dtype4)
963
964                    def check_dot22scalar(func, len_topo_scalar=-1):
965                        topo = func.maker.fgraph.toposort()
966                        ops = [x.op for x in topo]
967                        dtype4_upcast = theano.scalar.upcast(dtype4, dtype1,
968                                                             dtype2)
969
970                        if dtype1 == dtype2 == dtype3 == dtype4_upcast:
971                            if len_topo_scalar > 0:
972                                assert len(topo) == len_topo_scalar
973                            assert _dot22scalar in ops, (dtype1, dtype2,
974                                                         dtype3, dtype4)
975                        elif dtype1 == dtype2 == dtype4_upcast:
976                            if not (len_topo_scalar > 0):
977                                assert len(topo) == len_topo_scalar
978                                assert _dot22scalar in ops, (dtype1, dtype2,
979                                                             dtype3, dtype4)
980                            else:
981                                # Currently there is a problem of
982                                # optimization order The constant get
983                                # upcasted to float64 before we try to
984                                # merge it with the dot22 of
985                                # float32. So this prevent the merge.
986                                assert _dot22scalar in ops or _dot22 in ops, (
987                                    dtype1, dtype2, dtype3, dtype4)
988
989                        elif dtype1 == dtype2:
990                            assert _dot22 in ops, (dtype1, dtype2,
991                                                   dtype3, dtype4)
992                        else:
993                            check = [isinstance(o, T.Dot) for o in ops]
994                            assert any(check), (dtype1, dtype2, dtype3, dtype4)
995
996                    def cmp(a_shp, b_shp, c_shp, sqr_shp=(5, 5)):
997                        av = rng.uniform(size=a_shp).astype(dtype1)
998                        bv = rng.uniform(size=b_shp).astype(dtype2)
999                        cv = rng.uniform(size=c_shp).astype(dtype3)
1000                        sv = rng.uniform(size=sqr_shp).astype(dtype1)
1001
1002                        if False:
1003                            f = theano.function([a, b], cst * T.dot(a, b),
1004                                                mode=mode_blas_opt)
1005                            f.maker.fgraph.toposort()
1006                            check_dot22scalar(f, 1)
1007
1008                            f(av, bv)
1009
1010                        if True:
1011                            f = theano.function([a, b, c],
1012                                                cst * c * T.dot(a, b),
1013                                                mode=mode_blas_opt)
1014                            f.maker.fgraph.toposort()
1015                            check_dot22scalar(f, 2)
1016
1017                            f(av, bv, cv)
1018
1019                        f = theano.function([a, b, c],
1020                                            c * cst * T.dot(a, b),
1021                                            mode=mode_blas_opt)
1022                        f.maker.fgraph.toposort()
1023                        check_dot22scalar(f, 2)
1024                        f(av, bv, cv)
1025
1026                        # Here, canonicalize also seems needed
1027                        # TODO: add only the optimizations needed?
1028                        m2 = mode_blas_opt.including('canonicalize')
1029                        f = theano.function([a, b, c],
1030                                            cst2 * c * cst * T.dot(a, b),
1031                                            mode=m2)
1032                        f.maker.fgraph.toposort()
1033                        check_dot22scalar(f, 2)
1034                        f(av, bv, cv)
1035
1036                        if dtype1 == dtype2 == dtype3:
1037                            f = theano.function([a, b, c],
1038                                                c * cst * a * T.dot(a, b),
1039                                                mode=m2)
1040                            f.maker.fgraph.toposort()
1041                            check_dot22scalar(f, 2)
1042                            f(sv, sv, sv)
1043
1044                            f = theano.function([a, b, c],
1045                                                cst * c * a * T.dot(a, b),
1046                                                mode=mode_blas_opt)
1047                            f.maker.fgraph.toposort()
1048                            # currently the canonizer don't always
1049                            # merge all Mul together...  dot22scalar
1050                            # optimizer does not do a recursive search
1051                            # therefore, it doesn't find potential
1052                            # matches of the scalar.  TODO: combine
1053                            # with the 'canonicalization' that is part
1054                            # of the Gemm optimizer.
1055                            #
1056                            #    assert _dot22scalar in [x.op for x in topo]
1057                            #    assert len(topo)==2
1058                            f(sv, sv, sv)
1059
1060                            f = theano.function([a, b, c],
1061                                                c * a * cst * T.dot(a, b),
1062                                                mode=m2)
1063                            f.maker.fgraph.toposort()
1064                            check_dot22scalar(f, 2)
1065                            f(sv, sv, sv)
1066
1067                    cmp((3, 4), (4, 5), (3, 5))
1068                    cmp((0, 4), (4, 5), (0, 5))
1069                    cmp((3, 0), (0, 5), (3, 5))
1070                    cmp((3, 4), (4, 0), (3, 0), (0, 0))
1071                    cmp((0, 4), (4, 0), (0, 0))
1072                    cmp((0, 0), (0, 0), (0, 0))
1073
1074
1075def test_dot22scalar_cast():
1076    # Test that in `dot22_to_dot22scalar` we properly cast integers to floats.
1077    # Note that this test was failing before d5ff6904.
1078    A = T.dmatrix()
1079    for scalar_int_type in T.int_dtypes:
1080        y = T.scalar(dtype=scalar_int_type)
1081        f = theano.function([A, y], T.dot(A, A) * y, mode=mode_blas_opt)
1082        assert _dot22scalar in [x.op for x in f.maker.fgraph.toposort()]
1083    A = T.fmatrix()
1084    for scalar_int_type in T.int_dtypes:
1085        y = T.scalar(dtype=scalar_int_type)
1086        f = theano.function([A, y], T.dot(A, A) * y, mode=mode_blas_opt)
1087        if scalar_int_type in ['int32', 'int64']:
1088            assert _dot22 in [x.op for x in f.maker.fgraph.toposort()]
1089        else:
1090            assert _dot22scalar in [x.op for x in f.maker.fgraph.toposort()]
1091
1092
1093def test_local_dot22_to_dot22scalar():
1094    # This test that the bug in gh-1507 is really fixed
1095    A = T.dmatrix()
1096    mode = theano.compile.mode.get_default_mode()
1097    opt = theano.tensor.opt.in2out(
1098        theano.tensor.blas.local_dot22_to_dot22scalar)
1099    mode = mode.__class__(optimizer=opt)
1100
1101    x = T.dscalar()
1102    y = T.dscalar()
1103    z = T.dscalar()
1104    # make sure to don't have dimshuffle as we don't opt those cases
1105    m = T.dmatrix()
1106    r = T.drow()
1107    for idx, node in enumerate([
1108        # Old working cases
1109        T.mul(_dot22(A, A), x),
1110        T.mul(_dot22(A, A), x, y),
1111        T.mul(_dot22(A, A), x, r),
1112        T.mul(_dot22(A, A), m, x),
1113        T.mul(_dot22(A, A), x, m),
1114        T.mul(_dot22(A, A), x, (m * y)),
1115        T.mul(_dot22(A, A), (m * y), x),
1116        T.mul(_dot22(A, A), x, (r * y)),
1117        T.mul(_dot22(A, A), (r * y), x),
1118        T.mul(_dot22(A, A), (x * y), (m * x)),
1119        T.mul(_dot22(A, A), (r * y), (y * x)),
1120
1121        # Case that was raising an assert that is fixed in gh-1507
1122        T.mul(_dot22(A, A), (m * y), m),
1123        T.mul(_dot22(A, A), m, (m * y)),
1124        T.mul(_dot22(A, A), (r * y), (m * x)),
1125
1126        # assert fixed in gh-1507 and opt case added in gh-1515
1127        T.mul(_dot22(A, A), (m * y * z), m),
1128        T.mul(_dot22(A, A), m, (m * y * z)),
1129
1130        # Opt case added in gh-1515
1131        T.mul(_dot22(A, A), T.mul(m, y, z), m),
1132        T.mul(_dot22(A, A), m, T.mul(m, y, z)),
1133
1134        # Case that opt later in gh-1515
1135        T.mul(_dot22(A, A), (r * m), (m * x)),
1136    ]):
1137        node2 = theano.tensor.blas.local_dot22_to_dot22scalar.transform(
1138            node.owner)
1139        assert node2
1140        f = theano.function([x, y, z, m, r, A], node,
1141                            mode=mode, on_unused_input='ignore')
1142        f(.1, .2, .3, [[1, 2], [3, 4]], [[5, 6]], [[7, 8], [9, 10]])
1143
1144
1145def test_dot_w_self():
1146    # This can trigger problems in the optimization because what would
1147    # normally be a gemm must not be because the output is aliased to
1148    # one of the inputs.
1149
1150    A = shared(value=np.ones((2, 2)))
1151    B = T.matrix()
1152
1153    p = T.dot(A, A) * B
1154
1155    grad = T.grad(T.mean(p), A)
1156    f = theano.function([B], p, updates=[(A, A - grad)])
1157
1158    # tests correctness in debugmode
1159    f(np.asarray([[0, 1], [2, 3]], dtype=config.floatX))
1160
1161
1162###############################################################################
1163# Tests for Gemv
1164###############################################################################
1165
1166class TestGemv(TestCase, unittest_tools.TestOptimizationMixin):
1167    def test_dot_vv(self):
1168        # Currently we generate a gemv for that case
1169        rng = np.random.RandomState(unittest_tools.fetch_seed())
1170        v = theano.shared(np.array(rng.uniform(size=(2,)), dtype='float32'))
1171        w = theano.shared(np.array(rng.uniform(size=(2,)), dtype='float32'))
1172        f = theano.function([], theano.dot(v, w), mode=mode_blas_opt)
1173
1174        # Assert that the dot was optimized somehow
1175        self.assertFunctionContains0(f, T.dot)
1176        self.assertFunctionContains1(f, Gemv(True))
1177
1178        # Assert they produce the same output
1179        assert np.allclose(f(), np.dot(v.get_value(), w.get_value()))
1180
1181    def test_dot_vm(self):
1182        # Test vector dot matrix
1183        rng = np.random.RandomState(unittest_tools.fetch_seed())
1184        v = theano.shared(np.array(rng.uniform(size=(2,)), dtype='float32'))
1185        m = theano.shared(np.array(rng.uniform(size=(2, 3)), dtype='float32'))
1186        f = theano.function([], theano.dot(v, m), mode=mode_blas_opt)
1187
1188        # Assert that the dot was optimized somehow
1189        self.assertFunctionContains0(f, T.dot)
1190        self.assertFunctionContains1(f, Gemv(True))
1191
1192        # Assert they produce the same output
1193        assert np.allclose(f(), np.dot(v.get_value(), m.get_value()))
1194        # Assert it works when m has no contiguous dimension
1195        m.set_value(m.get_value(borrow=True)[::-1, ::-1], borrow=True)
1196        assert np.allclose(f(), np.dot(v.get_value(), m.get_value()))
1197
1198    def test_dot_mv(self):
1199        # Test matrix dot vector
1200        rng = np.random.RandomState(unittest_tools.fetch_seed())
1201        v = theano.shared(np.array(rng.uniform(size=(2,)), dtype='float32'))
1202        m = theano.shared(np.array(rng.uniform(size=(3, 2)), dtype='float32'))
1203        f = theano.function([], theano.dot(m, v), mode=mode_blas_opt)
1204
1205        # Assert that the dot was optimized somehow
1206        self.assertFunctionContains0(f, T.dot)
1207        self.assertFunctionContains1(f, Gemv(True))
1208
1209        # Assert they produce the same output
1210        assert np.allclose(f(), np.dot(m.get_value(), v.get_value()))
1211        # Assert it works when m has no contiguous dimension
1212        m.set_value(m.get_value(borrow=True)[::-1, ::-1], borrow=True)
1213        assert np.allclose(f(), np.dot(m.get_value(), v.get_value()))
1214
1215    @staticmethod
1216    def t_gemv1(m_shp):
1217        # test vector2+dot(matrix,vector1)
1218        rng = np.random.RandomState(unittest_tools.fetch_seed())
1219        v1 = theano.shared(np.array(rng.uniform(size=(m_shp[1],)),
1220                           dtype='float32'))
1221        v2_orig = np.array(rng.uniform(size=(m_shp[0],)), dtype='float32')
1222        v2 = theano.shared(v2_orig)
1223        m = theano.shared(np.array(rng.uniform(size=m_shp), dtype='float32'))
1224
1225        f = theano.function([], v2 + theano.dot(m, v1), mode=mode_blas_opt)
1226
1227        # Assert they produce the same output
1228        assert np.allclose(f(), np.dot(m.get_value(), v1.get_value()) + v2_orig)
1229        topo = f.maker.fgraph.toposort()
1230        assert len(topo) == 1
1231        assert isinstance(topo[0].op, Gemv)
1232        assert topo[0].op.inplace is False
1233
1234        # test the inplace version
1235        g = theano.function([], [], updates=[(v2, v2 + theano.dot(m, v1))],
1236                            mode=mode_blas_opt)
1237
1238        # Assert they produce the same output
1239        g()
1240        assert np.allclose(v2.get_value(), np.dot(m.get_value(),
1241                           v1.get_value()) + v2_orig)
1242        topo = g.maker.fgraph.toposort()
1243        assert len(topo) == 1
1244        assert isinstance(topo[0].op, Gemv)
1245        if config.mode != 'FAST_COMPILE':
1246            assert topo[0].op.inplace is True
1247
1248        # Do the same tests with a matrix with strides in both dimensions
1249        m.set_value(m.get_value(borrow=True)[::-1, ::-1],
1250                    borrow=True)
1251        v2.set_value(v2_orig)
1252        assert np.allclose(f(),
1253                           np.dot(m.get_value(), v1.get_value()) + v2_orig)
1254        g()
1255        assert np.allclose(v2.get_value(),
1256                           np.dot(m.get_value(), v1.get_value()) + v2_orig)
1257
1258    @attr('slow')
1259    def test_gemv1(self):
1260        self.t_gemv1((3, 2))
1261        self.t_gemv1((0, 2))
1262        self.t_gemv1((3, 0))
1263        self.t_gemv1((0, 0))
1264
1265    def test_gemv2(self):
1266        # test vector2+dot(vector1,matrix)
1267        rng = np.random.RandomState(unittest_tools.fetch_seed())
1268        v1 = theano.shared(np.array(rng.uniform(size=(2,)),
1269                           dtype='float32'))
1270        v2_orig = np.array(rng.uniform(size=(3,)), dtype='float32')
1271        v2 = theano.shared(v2_orig)
1272        m = theano.shared(np.array(rng.uniform(size=(2, 3)),
1273                          dtype='float32'))
1274
1275        f = theano.function([], v2 + theano.dot(v1, m), mode=mode_blas_opt)
1276
1277        # Assert they produce the same output
1278        assert np.allclose(f(),
1279                           np.dot(v1.get_value(), m.get_value()) +
1280                           v2.get_value())
1281        topo = f.maker.fgraph.toposort()
1282        assert sum(isinstance(node.op, Gemv) for node in topo) == 1
1283        assert topo[-1].op.inplace is False
1284
1285        # test the inplace version
1286        g = theano.function([], [], updates=[(v2, v2 + theano.dot(v1, m))],
1287                            mode=mode_blas_opt)
1288
1289        # Assert they produce the same output
1290        g()
1291        assert np.allclose(v2.get_value(),
1292                           np.dot(v1.get_value(), m.get_value()) + v2_orig)
1293        topo = g.maker.fgraph.toposort()
1294        assert sum(isinstance(node.op, Gemv) for node in topo) == 1
1295        if config.mode != 'FAST_COMPILE':
1296            assert topo[-1].op.inplace is True
1297
1298        # Do the same tests with a matrix with strides in both dimensions
1299        m.set_value(m.get_value(borrow=True)[::-1, ::-1],
1300                    borrow=True)
1301        v2.set_value(v2_orig)
1302        assert np.allclose(f(),
1303                           np.dot(v1.get_value(), m.get_value()) +
1304                           v2.get_value())
1305        g()
1306        assert np.allclose(v2.get_value(),
1307                           np.dot(v1.get_value(), m.get_value()) + v2_orig)
1308
1309    def test_gemv_broadcast(self):
1310        # test gemv with some broadcasted input
1311        rng = np.random.RandomState(unittest_tools.fetch_seed())
1312        v1 = theano.shared(np.array(rng.uniform(size=(2,)),
1313                                    dtype='float32'))
1314        v2_orig = np.array(rng.uniform(size=(1,)), dtype='float32')
1315        v2 = theano.shared(v2_orig)
1316        m = theano.shared(np.array(rng.uniform(size=(1, 2)),
1317                                   dtype='float32'),
1318                          broadcastable=(True, False))
1319        o = theano.dot(m, v1)
1320        f = theano.function([], o + v2, mode=mode_blas_opt)
1321
1322        # Assert they produce the same output
1323        assert np.allclose(
1324            f(),
1325            np.dot(m.get_value(), v1.get_value()) + v2.get_value())
1326        topo = f.maker.fgraph.toposort()
1327        assert sum(isinstance(node.op, Gemv) for node in topo) == 1
1328
1329        # call gemv directly for mixed broadcast pattern.
1330        o = theano.tensor.blas.gemv_no_inplace(v2, 0.5, m, v1, 0.25)
1331        f = theano.function([], o, mode=mode_blas_opt)
1332        assert np.allclose(
1333            f(),
1334            0.5 * np.dot(m.get_value(), v1.get_value()) + 0.25 * v2.get_value())
1335        topo = f.maker.fgraph.toposort()
1336        assert sum(isinstance(node.op, Gemv) for node in topo) == 1
1337
1338    def test_gemv_dimensions(self):
1339        A = T.matrix('A')
1340        x, y = T.vectors('x', 'y')
1341        alpha = theano.shared(theano._asarray(1.0, dtype=config.floatX),
1342                              name='alpha')
1343        beta = theano.shared(theano._asarray(1.0, dtype=config.floatX),
1344                             name='beta')
1345
1346        z = beta * y + alpha * T.dot(A, x)
1347        f = theano.function([A, x, y], z)
1348
1349        # Matrix value
1350        A_val = np.ones((5, 3), dtype=config.floatX)
1351        # Different vector length
1352        ones_3 = np.ones(3, dtype=config.floatX)
1353        ones_4 = np.ones(4, dtype=config.floatX)
1354        ones_5 = np.ones(5, dtype=config.floatX)
1355        ones_6 = np.ones(6, dtype=config.floatX)
1356
1357        f(A_val, ones_3, ones_5)
1358        f(A_val[::-1, ::-1], ones_3, ones_5)
1359        self.assertRaises(ValueError, f, A_val, ones_4, ones_5)
1360        self.assertRaises(ValueError, f, A_val, ones_3, ones_6)
1361        self.assertRaises(ValueError, f, A_val, ones_4, ones_6)
1362
1363# The following gemv tests were added in March 2011 by Ian Goodfellow
1364# and are based on the gemv tests from scipy
1365# http://projects.scipy.org/scipy/browser/trunk/scipy/linalg/tests/test_fblas.py?rev=6803
1366# NOTE: At the time these tests were written, theano did not have a
1367# conjugate function. If such a thing is ever added, the tests involving
1368# conjugate should be ported over as well.
1369
1370
1371def matrixmultiply(a, b):
1372    if len(b.shape) == 1:
1373        b_is_vector = True
1374        b = b[:, newaxis]
1375    else:
1376        b_is_vector = False
1377    assert a.shape[1] == b.shape[0]
1378    c = zeros((a.shape[0], b.shape[1]), common_type(a, b))
1379    for i in xrange(a.shape[0]):
1380        for j in xrange(b.shape[1]):
1381            s = 0
1382            for k in xrange(a.shape[1]):
1383                s += a[i, k] * b[k, j]
1384            c[i, j] = s
1385    if b_is_vector:
1386        c = c.reshape((a.shape[0],))
1387    return c
1388
1389
1390class BaseGemv(object):
1391    mode = mode_blas_opt  # can be overridden with self.mode
1392    shared = staticmethod(theano.shared)
1393
1394    def get_data(self, x_stride=1, y_stride=1):
1395        rng = np.random.RandomState(unittest_tools.fetch_seed())
1396        mult = array(1, dtype=self.dtype)
1397        if self.dtype in [complex64, complex128]:
1398            mult = array(1 + 1j, dtype=self.dtype)
1399        alpha = array(1., dtype=self.dtype) * mult
1400        beta = array(1., dtype=self.dtype) * mult
1401        a = rng.randn(3, 3).astype(self.dtype) * mult
1402        x = arange(shape(a)[0] * x_stride, dtype=self.dtype) * mult
1403        y = arange(shape(a)[1] * y_stride, dtype=self.dtype) * mult
1404        return alpha, beta, a, x, y
1405
1406    def test_simple(self):
1407        alpha, beta, a, x, y = [self.shared(value)
1408                                for value in self.get_data()]
1409        desired_oy = alpha.get_value() * matrixmultiply(a.get_value(), x.get_value()) + beta.get_value() * y.get_value()
1410
1411        oy = alpha * T.dot(a, x) + beta * y
1412
1413        oy_func = theano.function([], oy, mode=self.mode)
1414
1415        oy_func.maker.fgraph.toposort()
1416        self.assertFunctionContains1(oy_func, self.gemv)
1417
1418        oy_val = oy_func()
1419
1420        assert_array_almost_equal(desired_oy, oy_val)
1421
1422    def test_default_beta_y(self):
1423
1424        vs = self.get_data()
1425        alpha_v, beta_v, a_v, x_v, y_v = vs
1426        a = self.shared(a_v)
1427        x = self.shared(x_v)
1428
1429        desired_oy = matrixmultiply(a_v, x_v)
1430
1431        oy = T.dot(a, x)
1432
1433        oy_func = theano.function([], oy, mode=self.mode)
1434
1435        self.assertFunctionContains1(oy_func, self.gemv_inplace)
1436
1437        oy_v = oy_func()
1438        assert_array_almost_equal(desired_oy, oy_v)
1439
1440    def test_simple_transpose(self):
1441        vs = self.get_data()
1442        alpha_v, beta_v, a_v, x_v, y_v = vs
1443        alpha, beta, a, x, y = [self.shared(v) for v in vs]
1444
1445        desired_oy = alpha_v * matrixmultiply(transpose(a_v),
1446                                              x_v) + beta_v * y_v
1447
1448        oy = alpha * T.dot(a.T, x) + beta * y
1449
1450        oy_func = theano.function([], oy, mode=self.mode)
1451
1452        self.assertFunctionContains1(oy_func, self.gemv)
1453
1454        oy_v = oy_func()
1455        assert_array_almost_equal(desired_oy, oy_v)
1456
1457    def test_x_stride(self):
1458        vs = self.get_data(x_stride=2)
1459        alpha_v, beta_v, a_v, x_v, y_v = vs
1460        alpha, beta, a, x, y = [self.shared(v) for v in vs]
1461
1462        desired_oy = alpha_v * matrixmultiply(a_v, x_v[::2]) + beta_v * y_v
1463
1464        oy = alpha * T.dot(a, x[::2]) + beta * y
1465
1466        oy_func = theano.function([], oy, mode=self.mode)
1467
1468        self.assertFunctionContains1(oy_func, self.gemv)
1469
1470        oy_v = oy_func()
1471        assert_array_almost_equal(desired_oy, oy_v)
1472
1473    def test_x_stride_transpose(self):
1474        vs = self.get_data(x_stride=2)
1475        alpha_v, beta_v, a_v, x_v, y_v = vs
1476        alpha, beta, a, x, y = [self.shared(v) for v in vs]
1477
1478        desired_oy = alpha_v * matrixmultiply(transpose(a_v), x_v[::2]) + \
1479            beta_v * y_v
1480
1481        oy = alpha * T.dot(a.T, x[::2]) + beta * y
1482
1483        oy_func = theano.function([], oy, mode=self.mode)
1484
1485        self.assertFunctionContains1(oy_func, self.gemv)
1486
1487        oy_v = oy_func()
1488        assert_array_almost_equal(desired_oy, oy_v)
1489
1490    def test_y_stride(self):
1491        vs = self.get_data(y_stride=2)
1492        alpha_v, beta_v, a_v, x_v, y_v = vs
1493        alpha, beta, a, x, y = [self.shared(v) for v in vs]
1494
1495        desired_oy = alpha_v * matrixmultiply(a_v, x_v) + beta_v * y_v[::2]
1496
1497        oy = alpha * T.dot(a, x) + beta * y[::2]
1498
1499        oy_func = theano.function([], oy, mode=self.mode)
1500
1501        self.assertFunctionContains1(oy_func, self.gemv)
1502
1503        oy_v = oy_func()
1504        assert_array_almost_equal(desired_oy, oy_v)
1505
1506    def test_y_stride_transpose(self):
1507        vs = self.get_data(y_stride=2)
1508        alpha_v, beta_v, a_v, x_v, y_v = vs
1509        alpha, beta, a, x, y = [self.shared(v) for v in vs]
1510
1511        desired_oy = alpha_v * matrixmultiply(transpose(a_v),
1512                                              x_v) + beta_v * y_v[::2]
1513
1514        oy = alpha * T.dot(a.T, x) + beta * y[::2]
1515
1516        oy_func = theano.function([], oy, mode=self.mode)
1517
1518        self.assertFunctionContains1(oy_func, self.gemv)
1519
1520        oy_v = oy_func()
1521        assert_array_almost_equal(desired_oy, oy_v)
1522
1523    def test_a_strides(self):
1524        vs = self.get_data()
1525        alpha_v, beta_v, a_v, x_v, y_v = vs
1526        alpha, beta, a, x, y = [self.shared(v) for v in vs]
1527        a_v = a_v[::-1, ::-1]
1528        a.set_value(a.get_value(borrow=True,
1529                                return_internal_type=True)[::-1, ::-1],
1530                    borrow=True)
1531
1532        desired_oy = alpha_v * matrixmultiply(a_v, x_v) + beta_v * y_v
1533
1534        oy = alpha * T.dot(a, x) + beta * y
1535
1536        oy_func = theano.function([], oy, mode=self.mode)
1537
1538        self.assertFunctionContains1(oy_func, self.gemv)
1539
1540        oy_v = oy_func()
1541        assert_array_almost_equal(desired_oy, oy_v)
1542
1543    def test_a_strides_transpose(self):
1544        vs = self.get_data()
1545        alpha_v, beta_v, a_v, x_v, y_v = vs
1546        alpha, beta, a, x, y = [self.shared(v) for v in vs]
1547        a_v = a_v[::-1, ::-1]
1548        a.set_value(a.get_value(borrow=True,
1549                                return_internal_type=True)[::-1, ::-1],
1550                    borrow=True)
1551
1552        desired_oy = alpha_v * matrixmultiply(transpose(a_v),
1553                                              x_v) + beta_v * y_v
1554
1555        oy = alpha * T.dot(a.T, x) + beta * y
1556
1557        oy_func = theano.function([], oy, mode=self.mode)
1558
1559        self.assertFunctionContains1(oy_func, self.gemv)
1560
1561        oy_v = oy_func()
1562        assert_array_almost_equal(desired_oy, oy_v)
1563
1564    def test_upcasting_scalar_nogemv(self):
1565        # Test that the optimization does not crash when the scale has
1566        # an incorrect dtype, and forces upcasting of the result
1567        # We put this test in this class to test it on the gpu too.
1568        vs = self.get_data()
1569        alpha_v, beta_v, a_v, x_v, y_v = vs
1570        alpha_v = alpha_v.astype("float64")
1571        a_v = a_v.astype("float32")
1572        x_v = x_v.astype("float32")
1573        y_v = y_v.astype("float32")
1574
1575        alpha = T.dscalar('alpha')
1576        a = self.shared(a_v)
1577        x = self.shared(x_v)
1578        y = self.shared(y_v)
1579
1580        rval = T.dot(a, x) * alpha + y
1581
1582        f = theano.function([alpha], rval, mode=self.mode)
1583        # this function is currently optimized so that the gemv is
1584        # done inplace on a temporarily allocated-buffer, which is
1585        # then scaled by alpha and to t with a fused elemwise.
1586        n_gemvs = 0
1587        # theano.printing.debugprint(f, print_type=True)
1588        for node in f.maker.fgraph.toposort():
1589            if node.op == self.gemv_inplace:
1590                n_gemvs += 1
1591                assert node.outputs[0].dtype == 'float32'
1592        assert n_gemvs == 1, n_gemvs
1593        self.assertFunctionContains1(f, self.gemv_inplace)
1594        f(alpha_v)
1595
1596
1597class TestSgemv(TestCase, BaseGemv, unittest_tools.TestOptimizationMixin):
1598    dtype = float32
1599    gemv = theano.tensor.blas.gemv_no_inplace
1600    gemv_inplace = theano.tensor.blas.gemv_inplace
1601
1602
1603class TestDgemv(TestCase, BaseGemv, unittest_tools.TestOptimizationMixin):
1604    dtype = float64
1605    gemv = theano.tensor.blas.gemv_no_inplace
1606    gemv_inplace = theano.tensor.blas.gemv_inplace
1607
1608# The optimization to put Gemv don't work for complex type for now.
1609# See ticket 653.
1610# class TestCgemv(TestCase, BaseGemv):
1611#    dtype = complex64
1612
1613# class TestZgemv(TestCase, BaseGemv):
1614#    dtype = complex128
1615
1616###############################################################################
1617# Tests for Ger
1618###############################################################################
1619
1620
1621class TestGer_make_node(TestCase):
1622    def setUp(self):
1623        self.iv = T.tensor(dtype='int32', broadcastable=(False,))
1624        self.fv = T.tensor(dtype='float32', broadcastable=(False,))
1625        self.fv1 = T.tensor(dtype='float32', broadcastable=(True,))
1626        self.dv = T.tensor(dtype='float64', broadcastable=(False,))
1627        self.dv1 = T.tensor(dtype='float64', broadcastable=(True,))
1628        self.cv = T.tensor(dtype='complex64', broadcastable=(False,))
1629        self.zv = T.tensor(dtype='complex128', broadcastable=(False,))
1630
1631        self.fv_2 = T.tensor(dtype='float32', broadcastable=(False,))
1632        self.fv1_2 = T.tensor(dtype='float32', broadcastable=(True,))
1633        self.dv_2 = T.tensor(dtype='float64', broadcastable=(False,))
1634        self.dv1_2 = T.tensor(dtype='float64', broadcastable=(True,))
1635        self.cv_2 = T.tensor(dtype='complex64', broadcastable=(False,))
1636        self.zv_2 = T.tensor(dtype='complex128', broadcastable=(False,))
1637
1638        self.fm = T.fmatrix()
1639        self.dm = T.dmatrix()
1640        self.cm = T.cmatrix()
1641        self.zm = T.zmatrix()
1642
1643        self.fa = T.fscalar()
1644        self.da = T.dscalar()
1645        self.ca = T.cscalar()
1646        self.za = T.zscalar()
1647
1648    def test_works_on_all_valid_dtypes(self):
1649        self.assertEqual(self.fm.type,
1650                         ger(self.fm, self.fa, self.fv, self.fv_2).type)
1651        self.assertEqual(self.fm.type,
1652                         ger(self.fm, self.fa, self.fv, self.fv_2).type)
1653        self.assertEqual(self.fm.type,
1654                         ger(self.fm, self.fa, self.fv, self.fv_2).type)
1655        self.assertEqual(self.fm.type,
1656                         ger(self.fm, self.fa, self.fv, self.fv_2).type)
1657
1658    def test_fails_on_invalid_dtypes(self):
1659        self.assertRaises(TypeError,
1660                          ger, T.imatrix(), T.iscalar(), T.ivector(),
1661                          T.ivector())
1662
1663    def test_fails_for_nonscalar_alpha(self):
1664        self.assertRaises(TypeError,
1665                          ger, self.fm, self.fm, self.fv, self.fv_2)
1666        # boundary case - fv1 has the right dtype and could be dimshuffled to a
1667        # scalar, but that's not make_node's job.
1668        self.assertRaises(TypeError,
1669                          ger, self.fm, self.fv1, self.fv, self.fv_2)
1670        # actually doing the aforementioned dimshuffle makes it work
1671        self.assertEqual(self.fm.type,
1672                         ger(self.fm, self.fv1.dimshuffle(), self.fv,
1673                             self.fv_2).type)
1674
1675    def test_fails_for_nonmatrix_A(self):
1676        self.assertRaises(TypeError,
1677                          ger, self.fv, self.fa, self.fv, self.fv_2)
1678
1679    def test_fails_for_nonvector_x_or_y(self):
1680        self.assertRaises(TypeError,
1681                          ger, self.fm, self.fa,
1682                          self.fv.dimshuffle('x', 0), self.fv_2)
1683        self.assertRaises(TypeError,
1684                          ger, self.fm, self.fa,
1685                          self.fv, self.fv_2.dimshuffle('x', 0))
1686
1687    def test_fails_for_mixed_dtypes(self):
1688        self.assertRaises(TypeError, ger, self.dm, self.fa, self.fv, self.fv_2)
1689        self.assertRaises(TypeError, ger, self.fm, self.da, self.fv, self.fv_2)
1690        self.assertRaises(TypeError, ger, self.fm, self.fa, self.dv, self.fv_2)
1691        self.assertRaises(TypeError, ger, self.fm, self.fa, self.fv, self.dv_2)
1692        self.assertRaises(TypeError, ger, self.cm, self.fa, self.fv, self.dv_2)
1693        self.assertRaises(TypeError, ger, self.cm, self.fa, self.fv, self.zv_2)
1694
1695
1696class TestGer_OpContract(TestCase, unittest_tools.T_OpContractMixin):
1697    def setUp(self):
1698        self.ops = [ger, ger_destructive]
1699
1700    def clone(self, op):
1701        return Ger(op.destructive)
1702
1703
1704class TestGer(TestCase, unittest_tools.TestOptimizationMixin):
1705    shared = staticmethod(theano.shared)
1706
1707    def setUp(self):
1708        self.mode = theano.compile.get_default_mode().including('fast_run')
1709        self.mode = self.mode.excluding('c_blas', 'scipy_blas')
1710        dtype = self.dtype = 'float64'  # optimization isn't dtype-dependent
1711        self.A = T.tensor(dtype=dtype, broadcastable=(False, False))
1712        self.a = T.tensor(dtype=dtype, broadcastable=())
1713        self.x = T.tensor(dtype=dtype, broadcastable=(False,))
1714        self.y = T.tensor(dtype=dtype, broadcastable=(False,))
1715        self.ger = ger
1716        self.ger_destructive = ger_destructive
1717        self.gemm = gemm_no_inplace
1718
1719    def function(self, inputs, outputs, updates=None):
1720        if updates is None:
1721            updates = []
1722        return theano.function(inputs, outputs, self.mode, updates=updates)
1723
1724    def b(self, bval):
1725        return T.as_tensor_variable(np.asarray(bval, dtype=self.dtype))
1726
1727    def test_b_0_triggers_ger(self):
1728        # test local_gemm_to_ger opt
1729        assert T.blas.local_gemm_to_ger.transform(
1730            gemm_no_inplace(self.A, self.a, self.x.dimshuffle(0, 'x'),
1731                            self.y.dimshuffle('x', 0), self.b(0)).owner)
1732
1733    def test_b_1_triggers_ger(self):
1734        # test local_gemm_to_ger opt
1735        assert T.blas.local_gemm_to_ger.transform(
1736            gemm_no_inplace(self.A, self.a, self.x.dimshuffle(0, 'x'),
1737                            self.y.dimshuffle('x', 0), self.b(1)).owner)
1738
1739    def test_b_other_does_not_triggers_ger(self):
1740        # test local_gemm_to_ger opt
1741        assert not T.blas.local_gemm_to_ger.transform(
1742            gemm_no_inplace(self.A, self.a, self.x.dimshuffle(0, 'x'),
1743                            self.y.dimshuffle('x', 0), self.b(1.5)).owner)
1744
1745    def test_b_nonconst_does_not_triggers_ger(self):
1746        # test local_gemm_to_ger opt
1747        assert not T.blas.local_gemm_to_ger.transform(
1748            gemm_no_inplace(self.A, self.a, self.x.dimshuffle(0, 'x'),
1749                            self.y.dimshuffle('x', 0), self.a).owner)
1750
1751    def test_outer(self):
1752        f = self.function([self.x, self.y], T.outer(self.x, self.y))
1753        self.assertFunctionContains(f, self.ger_destructive)
1754        f(np.random.rand(5).astype(self.dtype),
1755          np.random.rand(4).astype(self.dtype))
1756
1757    def test_A_plus_outer(self):
1758        f = self.function([self.A, self.x, self.y],
1759                          self.A + T.outer(self.x, self.y))
1760        self.assertFunctionContains(f, self.ger)
1761        f(np.random.rand(5, 4).astype(self.dtype),
1762          np.random.rand(5).astype(self.dtype),
1763          np.random.rand(4).astype(self.dtype))
1764        f(np.random.rand(5, 4).astype(self.dtype)[::-1, ::-1],
1765          np.random.rand(5).astype(self.dtype),
1766          np.random.rand(4).astype(self.dtype))
1767
1768    def test_A_plus_scaled_outer(self):
1769        f = self.function([self.A, self.x, self.y],
1770                          self.A + 0.1 * T.outer(self.x, self.y))
1771        self.assertFunctionContains(f, self.ger)
1772        f(np.random.rand(5, 4).astype(self.dtype),
1773          np.random.rand(5).astype(self.dtype),
1774          np.random.rand(4).astype(self.dtype))
1775        f(np.random.rand(5, 4).astype(self.dtype)[::-1, ::-1],
1776          np.random.rand(5).astype(self.dtype),
1777          np.random.rand(4).astype(self.dtype))
1778
1779    def test_scaled_A_plus_scaled_outer(self):
1780        f = self.function([self.A, self.x, self.y],
1781                          np.asarray(0.2, self.dtype) * self.A +
1782                          np.asarray(0.1, self.dtype) * T.outer(
1783                          self.x, self.y))
1784        # Why gemm? This make the graph simpler did we test that it
1785        # make it faster?
1786        self.assertFunctionContains(f, self.gemm)
1787        f(np.random.rand(5, 4).astype(self.dtype),
1788          np.random.rand(5).astype(self.dtype),
1789          np.random.rand(4).astype(self.dtype))
1790        f(np.random.rand(5, 4).astype(self.dtype)[::-1, ::-1],
1791          np.random.rand(5).astype(self.dtype),
1792          np.random.rand(4).astype(self.dtype))
1793
1794    def given_dtype(self, dtype, M, N):
1795        # test corner case shape and dtype
1796
1797        f = self.function([self.A, self.x, self.y],
1798                          self.A + 0.1 * T.outer(self.x, self.y))
1799        self.assertFunctionContains(f, self.ger)
1800        f(np.random.rand(M, N).astype(self.dtype),
1801          np.random.rand(M).astype(self.dtype),
1802          np.random.rand(N).astype(self.dtype))
1803        f(np.random.rand(M, N).astype(self.dtype)[::-1, ::-1],
1804          np.random.rand(M).astype(self.dtype),
1805          np.random.rand(N).astype(self.dtype))
1806
1807    def test_f32_0_0(self):
1808        return self.given_dtype('float32', 0, 0)
1809
1810    def test_f32_1_0(self):
1811        return self.given_dtype('float32', 1, 0)
1812
1813    def test_f32_0_1(self):
1814        return self.given_dtype('float32', 0, 1)
1815
1816    def test_f32_1_1(self):
1817        return self.given_dtype('float32', 1, 1)
1818
1819    def test_f32_4_4(self):
1820        return self.given_dtype('float32', 4, 4)
1821
1822    def test_f32_7_1(self):
1823        return self.given_dtype('float32', 7, 1)
1824
1825    def test_f32_1_2(self):
1826        return self.given_dtype('float32', 1, 2)
1827
1828    def test_f64_4_5(self):
1829        return self.given_dtype('float64', 4, 5)
1830
1831    def test_c64_7_1(self):
1832        return self.given_dtype('complex64', 7, 1)
1833
1834    def test_c128_1_9(self):
1835        return self.given_dtype('complex128', 1, 9)
1836
1837    def test_inplace(self):
1838        A = self.shared(np.random.rand(4, 5).astype(self.dtype))
1839        f = self.function([self.x, self.y], [],
1840                          updates=[(A, A + T.constant(0.1, dtype=self.dtype) *
1841                                   T.outer(self.x, self.y))])
1842        self.assertFunctionContains(f, self.ger_destructive)
1843        f(np.random.rand(4).astype(self.dtype),
1844          np.random.rand(5).astype(self.dtype))
1845
1846        A.set_value(
1847            A.get_value(borrow=True, return_internal_type=True)[::-1, ::-1],
1848            borrow=True)
1849        f(np.random.rand(4).astype(self.dtype),
1850          np.random.rand(5).astype(self.dtype))
1851
1852
1853class TestBlasStrides(TestCase):
1854    dtype = 'float64'
1855    shared = staticmethod(tensor._shared)
1856    mode = theano.compile.get_default_mode()
1857    mode = mode.including('fast_run').excluding('gpu', 'c_blas', 'scipy_blas')
1858    rng = np.random.RandomState(seed=unittest_tools.fetch_seed())
1859
1860    def rand(self, *shape):
1861        return theano._asarray(self.rng.rand(*shape), dtype=self.dtype)
1862
1863    def cmp_dot22(self, b_shp, c_shp):
1864        av = np.zeros((0, 0), dtype=self.dtype)
1865        bv = self.rand(*b_shp)
1866        cv = self.rand(*c_shp)
1867
1868        a = self.shared(av, 'a')
1869        b = self.shared(bv, 'b')
1870        c = self.shared(cv, 'c')
1871
1872        b_t = self.shared(bv.T, 'b.T')
1873        c_t = self.shared(cv.T, 'c.T')
1874
1875        b_dev = b.get_value(borrow=False, return_internal_type=True)
1876        c_dev = c.get_value(borrow=False, return_internal_type=True)
1877        bt_dev = b_t.get_value(borrow=False, return_internal_type=True)
1878        ct_dev = c_t.get_value(borrow=False, return_internal_type=True)
1879
1880        f_nn = theano.function([], [], updates=[(a, tensor.dot(b, c))],
1881                               mode=self.mode)
1882        # print 'class name:', self.__class__.__name__
1883        # theano.printing.debugprint(f_nn)
1884        f_nt = theano.function([], [], updates=[(a, tensor.dot(b, c_t.T))],
1885                               mode=self.mode)
1886        f_tn = theano.function([], [], updates=[(a, tensor.dot(b_t.T, c))],
1887                               mode=self.mode)
1888        f_tt = theano.function([], [], updates=[(a, tensor.dot(b_t.T, c_t.T))],
1889                               mode=self.mode)
1890
1891        # Try with all stride patterns, and all transposed pattern
1892        for step_signs in itertools_product((-1, 1), repeat=4):
1893            for step in (1, 2):
1894                b_step1, b_step2, c_step1, c_step2 = (s * step
1895                                                      for s in step_signs)
1896
1897                b.set_value(b_dev.copy()[::b_step1, ::b_step2], borrow=True)
1898                c.set_value(c_dev.copy()[::c_step1, ::c_step2], borrow=True)
1899                b_t.set_value(bt_dev.copy()[::b_step2, ::b_step1], borrow=True)
1900                c_t.set_value(ct_dev.copy()[::c_step2, ::c_step1], borrow=True)
1901
1902                # Numpy result
1903                a_n = np.dot(bv[::b_step1, ::b_step2],
1904                             cv[::c_step1, ::c_step2])
1905
1906                f_nn()
1907                assert np.allclose(a.get_value(), a_n)
1908
1909                f_nt()
1910                assert np.allclose(a.get_value(), a_n)
1911
1912                f_tn()
1913                assert np.allclose(a.get_value(), a_n)
1914
1915                f_tt()
1916                assert np.allclose(a.get_value(), a_n)
1917
1918    def test_dot22(self):
1919        self.cmp_dot22((3, 4), (4, 5))
1920        self.cmp_dot22((1, 4), (4, 5))
1921        self.cmp_dot22((3, 4), (4, 1))
1922        self.cmp_dot22((3, 1), (1, 1))
1923        self.cmp_dot22((1, 4), (4, 1))
1924        self.cmp_dot22((3, 1), (1, 5))
1925        self.cmp_dot22((0, 4), (4, 5))
1926        self.cmp_dot22((0, 4), (4, 1))
1927        self.cmp_dot22((0, 1), (1, 5))
1928        self.cmp_dot22((3, 4), (4, 0))
1929        self.cmp_dot22((3, 0), (0, 5))
1930        self.cmp_dot22((0, 4), (4, 0))
1931        self.cmp_dot22((0, 0), (0, 0))
1932
1933    def cmp_dot22scalar(self, b_shp, c_shp):
1934        av = np.zeros((0, 0), dtype=self.dtype)
1935        bv = self.rand(*b_shp)
1936        cv = self.rand(*c_shp)
1937        l = np.float32(0.2)
1938
1939        a = self.shared(av, 'a')
1940        b = self.shared(bv, 'b')
1941        c = self.shared(cv, 'c')
1942
1943        b_t = self.shared(bv.T, 'b.T')
1944        c_t = self.shared(cv.T, 'c.T')
1945
1946        b_dev = b.get_value(borrow=False, return_internal_type=True)
1947        c_dev = c.get_value(borrow=False, return_internal_type=True)
1948        bt_dev = b_t.get_value(borrow=False, return_internal_type=True)
1949        ct_dev = c_t.get_value(borrow=False, return_internal_type=True)
1950
1951        f_nn = theano.function([], [], updates=[(a, l * tensor.dot(b, c))],
1952                               mode=self.mode)
1953        f_nt = theano.function([], [], updates=[(a, l * tensor.dot(b, c_t.T))],
1954                               mode=self.mode)
1955        f_tn = theano.function([], [], updates=[(a, l * tensor.dot(b_t.T, c))],
1956                               mode=self.mode)
1957        f_tt = theano.function([], [],
1958                               updates=[(a, l * tensor.dot(b_t.T, c_t.T))],
1959                               mode=self.mode)
1960
1961        # Try with all stride patterns, and all transposed pattern
1962        for step_signs in itertools_product((-1, 1), repeat=4):
1963            for step in (1, 2):
1964                b_step1, b_step2, c_step1, c_step2 = (s * step
1965                                                      for s in step_signs)
1966
1967                b.set_value(b_dev.copy()[::b_step1, ::b_step2], borrow=True)
1968                c.set_value(c_dev.copy()[::c_step1, ::c_step2], borrow=True)
1969                b_t.set_value(bt_dev.copy()[::b_step2, ::b_step1], borrow=True)
1970                c_t.set_value(ct_dev.copy()[::c_step2, ::c_step1], borrow=True)
1971
1972                # Numpy result
1973                a_n = l * np.dot(bv[::b_step1, ::b_step2],
1974                                 cv[::c_step1, ::c_step2])
1975
1976                f_nn()
1977                assert np.allclose(a.get_value(), a_n)
1978
1979                f_nt()
1980                assert np.allclose(a.get_value(), a_n)
1981
1982                f_tn()
1983                assert np.allclose(a.get_value(), a_n)
1984
1985                f_tt()
1986                assert np.allclose(a.get_value(), a_n)
1987
1988    def test_dot22scalar(self):
1989        self.cmp_dot22scalar((3, 4), (4, 5))
1990        self.cmp_dot22scalar((1, 4), (4, 5))
1991        self.cmp_dot22scalar((3, 4), (4, 1))
1992        self.cmp_dot22scalar((3, 1), (1, 1))
1993        self.cmp_dot22scalar((1, 4), (4, 1))
1994        self.cmp_dot22scalar((3, 1), (1, 5))
1995        self.cmp_dot22scalar((0, 4), (4, 5))
1996        self.cmp_dot22scalar((0, 4), (4, 1))
1997        self.cmp_dot22scalar((0, 1), (1, 5))
1998        self.cmp_dot22scalar((3, 4), (4, 0))
1999        self.cmp_dot22scalar((3, 0), (0, 5))
2000        self.cmp_dot22scalar((0, 4), (4, 0))
2001        self.cmp_dot22scalar((0, 0), (0, 0))
2002
2003    def cmp_gemm(self, a_shp, b_shp, c_shp):
2004        av = self.rand(*a_shp)
2005        bv = self.rand(*b_shp)
2006        cv = self.rand(*c_shp)
2007        l = np.float32(0.2)
2008
2009        a = self.shared(av, 'a')
2010        b = self.shared(bv, 'b')
2011        c = self.shared(cv, 'c')
2012
2013        a_t = self.shared(av.T, 'a.T')
2014        b_t = self.shared(bv.T, 'b.T')
2015        c_t = self.shared(cv.T, 'c.T')
2016
2017        a_dev = a.get_value(borrow=False, return_internal_type=True)
2018        b_dev = b.get_value(borrow=False, return_internal_type=True)
2019        c_dev = c.get_value(borrow=False, return_internal_type=True)
2020        bt_dev = b_t.get_value(borrow=False, return_internal_type=True)
2021        ct_dev = c_t.get_value(borrow=False, return_internal_type=True)
2022
2023        f_nnn = theano.function(
2024            [], [],
2025            updates=[(a, (l * a + tensor.dot(b, c)))],
2026            mode=self.mode)
2027        f_nnt = theano.function(
2028            [], [],
2029            updates=[(a, (l * a + tensor.dot(b, c_t.T)))],
2030            mode=self.mode)
2031        f_ntn = theano.function(
2032            [], [],
2033            updates=[(a, (l * a + tensor.dot(b_t.T, c)))],
2034            mode=self.mode)
2035        f_ntt = theano.function(
2036            [], [],
2037            updates=[(a, (l * a + tensor.dot(b_t.T, c_t.T)))],
2038            mode=self.mode)
2039        f_tnn = theano.function(
2040            [], [],
2041            updates=[(a_t, (l * a_t + tensor.dot(b, c).T))],
2042            mode=self.mode)
2043        f_tnt = theano.function(
2044            [], [],
2045            updates=[(a_t, (l * a_t + tensor.dot(b, c_t.T).T))],
2046            mode=self.mode)
2047        f_ttn = theano.function(
2048            [], [],
2049            updates=[(a_t, (l * a_t + tensor.dot(b_t.T, c).T))],
2050            mode=self.mode)
2051        f_ttt = theano.function(
2052            [], [],
2053            updates=[(a_t, (l * a_t + tensor.dot(b_t.T, c_t.T).T))],
2054            mode=self.mode)
2055
2056        # Try with all stride patterns, and all transposed pattern
2057        for step_signs in itertools_product((-1, 1), repeat=6):
2058            for step in (1, 2):
2059                a_step1, a_step2, b_step1, b_step2, c_step1, c_step2 = \
2060                    (s * step for s in step_signs)
2061
2062                b.set_value(b_dev.copy()[::b_step1, ::b_step2], borrow=True)
2063                c.set_value(c_dev.copy()[::c_step1, ::c_step2], borrow=True)
2064                b_t.set_value(bt_dev.copy()[::b_step2, ::b_step1], borrow=True)
2065                c_t.set_value(ct_dev.copy()[::c_step2, ::c_step1], borrow=True)
2066
2067                # Numpy results
2068                a_n = (l * av[::a_step1, ::a_step2] +
2069                       np.dot(bv[::b_step1, ::b_step2],
2070                              cv[::c_step1, ::c_step2]))
2071                at_n = (l * av[::a_step1, ::a_step2].T +
2072                        np.dot(bv[::b_step1, ::b_step2],
2073                               cv[::c_step1, ::c_step2]).T)
2074
2075                # a's value is updated, so we need to reinitialize it each time
2076                a.set_value(a_dev.copy()[::a_step1, ::a_step2], borrow=True)
2077                f_nnn()
2078                assert np.allclose(a.get_value(), a_n)
2079
2080                a.set_value(a_dev.copy()[::a_step1, ::a_step2], borrow=True)
2081                f_nnt()
2082                assert np.allclose(a.get_value(), a_n)
2083
2084                a.set_value(a_dev.copy()[::a_step1, ::a_step2], borrow=True)
2085                f_ntn()
2086                assert np.allclose(a.get_value(), a_n)
2087
2088                a.set_value(a_dev.copy()[::a_step1, ::a_step2], borrow=True)
2089                f_ntt()
2090                assert np.allclose(a.get_value(), a_n)
2091
2092                a_t.set_value(transpose(a_dev.copy())[::a_step2, ::a_step1],
2093                              borrow=True)
2094                f_tnn()
2095                assert np.allclose(a_t.get_value(), at_n)
2096
2097                a_t.set_value(transpose(a_dev.copy())[::a_step2, ::a_step1],
2098                              borrow=True)
2099                f_tnt()
2100                assert np.allclose(a_t.get_value(), at_n)
2101
2102                a_t.set_value(transpose(a_dev.copy())[::a_step2, ::a_step1],
2103                              borrow=True)
2104                f_ttn()
2105                assert np.allclose(a_t.get_value(), at_n)
2106
2107                a_t.set_value(transpose(a_dev.copy())[::a_step2, ::a_step1],
2108                              borrow=True)
2109                f_ttt()
2110                assert np.allclose(a_t.get_value(), at_n)
2111
2112    def test_gemm(self):
2113        self.cmp_gemm((3, 5), (3, 4), (4, 5))
2114        self.cmp_gemm((1, 5), (1, 4), (4, 5))
2115        self.cmp_gemm((3, 1), (3, 4), (4, 1))
2116        self.cmp_gemm((3, 1), (3, 1), (1, 1))
2117        self.cmp_gemm((1, 1), (1, 4), (4, 1))
2118        self.cmp_gemm((3, 5), (3, 1), (1, 5))
2119        self.cmp_gemm((0, 5), (0, 4), (4, 5))
2120        self.cmp_gemm((0, 1), (0, 4), (4, 1))
2121        self.cmp_gemm((0, 5), (0, 1), (1, 5))
2122        self.cmp_gemm((3, 0), (3, 4), (4, 0))
2123        self.cmp_gemm((3, 5), (3, 0), (0, 5))
2124        self.cmp_gemm((0, 0), (0, 4), (4, 0))
2125        self.cmp_gemm((0, 0), (0, 0), (0, 0))
2126
2127    def cmp_gemv(self, a_shp, b_shp, c_shp):
2128        av = self.rand(a_shp)
2129        bv = self.rand(*b_shp)
2130        cv = self.rand(c_shp)
2131        l = np.float32(0.2)
2132
2133        a = self.shared(av, 'a')
2134        b = self.shared(bv, 'b')
2135        c = self.shared(cv, 'c')
2136        b_t = self.shared(bv.T, 'b.T')
2137
2138        a_dev = a.get_value(borrow=False, return_internal_type=True)
2139        b_dev = b.get_value(borrow=False, return_internal_type=True)
2140        c_dev = c.get_value(borrow=False, return_internal_type=True)
2141
2142        f_n = theano.function([], [], updates=[(a, (a + l * tensor.dot(b, c)))],
2143                              mode=self.mode)
2144
2145        f_t = theano.function([], [],
2146                              updates=[(a, (a + l * tensor.dot(b_t.T, c)))],
2147                              mode=self.mode)
2148
2149        # Try with all stride patterns, and all transposed pattern
2150        for step_signs in itertools_product((1, -1), repeat=4):
2151            for step in (1, 2):
2152                a_step, b_step1, b_step2, c_step = (s * step
2153                                                    for s in step_signs)
2154
2155                a.set_value(a_dev.copy()[::a_step], borrow=True)
2156                b.set_value(b_dev.copy()[::b_step1, ::b_step2],
2157                            borrow=True)
2158                b_t.set_value(transpose(b_dev.copy())[::b_step2, ::b_step1],
2159                              borrow=True)
2160                c.set_value(c_dev.copy()[::c_step], borrow=True)
2161
2162                a_n = (av[::a_step] +
2163                       l * np.dot(bv[::b_step1, ::b_step2],
2164                                  cv[::c_step]))
2165                f_n()
2166                assert np.allclose(a.get_value(), a_n), (a.get_value(), a_n)
2167
2168                a.set_value(a_dev.copy()[::a_step], borrow=True)
2169                f_t()
2170                assert np.allclose(a.get_value(), a_n), (a.get_value(), a_n)
2171
2172    def test_gemv(self):
2173        self.cmp_gemv(3, (3, 5), 5)
2174        self.cmp_gemv(1, (1, 5), 5)
2175        self.cmp_gemv(3, (3, 1), 1)
2176        self.cmp_gemv(0, (0, 5), 5)
2177        self.cmp_gemv(3, (3, 0), 0)
2178        self.cmp_gemv(0, (0, 1), 1)
2179        self.cmp_gemv(1, (1, 0), 0)
2180        self.cmp_gemv(0, (0, 0), 0)
2181
2182    def cmp_ger(self, a_shp, b_shp, c_shp):
2183        av = self.rand(*a_shp)
2184        bv = self.rand(b_shp)
2185        cv = self.rand(c_shp)
2186        l = np.float32(0.2)
2187
2188        a = self.shared(av, 'a')
2189        b = self.shared(bv, 'b')
2190        c = self.shared(cv, 'c')
2191        a_t = self.shared(av.T, 'a.T')
2192
2193        a_dev = a.get_value(borrow=False, return_internal_type=True)
2194        b_dev = b.get_value(borrow=False, return_internal_type=True)
2195        c_dev = c.get_value(borrow=False, return_internal_type=True)
2196
2197        f_n = theano.function(
2198            [], [],
2199            updates=[(a, (a + l * tensor.outer(b, c)))],
2200            mode=self.mode)
2201
2202        f_t = theano.function(
2203            [], [],
2204            updates=[(a_t, (a_t + l * tensor.outer(b, c).T))],
2205            mode=self.mode)
2206
2207        # Try with all stride patterns, and all transposed patterns
2208        for step_signs in itertools_product((1, -1), repeat=4):
2209            for step in (1, 2):
2210                a_step1, a_step2, b_step, c_step = (s * step
2211                                                    for s in step_signs)
2212
2213                a.set_value(a_dev.copy()[::a_step1, ::a_step2], borrow=True)
2214                a_t.set_value(transpose(a_dev.copy())[::a_step1, ::a_step2],
2215                              borrow=True)
2216                b.set_value(b_dev.copy()[::b_step], borrow=True)
2217                c.set_value(c_dev.copy()[::c_step], borrow=True)
2218
2219                f_n()
2220                n_n = (av[::a_step1, ::a_step2] +
2221                       l * np.outer(bv[::b_step], cv[::c_step]))
2222                assert np.allclose(a.get_value(), n_n), (a.get_value(), n_n)
2223
2224                f_t()
2225                n_t = (av.T[::a_step1, ::a_step2] +
2226                       l * np.outer(bv[::b_step], cv[::c_step]).T)
2227                assert np.allclose(a_t.get_value(), n_t), (a_t.get_value(), n_t)
2228
2229    def test_ger_strides(self):
2230        self.cmp_ger((3, 5), 3, 5)
2231        self.cmp_ger((1, 5), 1, 5)
2232        self.cmp_ger((3, 1), 3, 1)
2233        self.cmp_ger((0, 5), 0, 5)
2234        self.cmp_ger((3, 0), 3, 0)
2235        self.cmp_ger((0, 1), 0, 1)
2236        self.cmp_ger((1, 0), 1, 0)
2237        self.cmp_ger((0, 0), 0, 0)
2238
2239    def test_gemm_non_contiguous(self):
2240        # test_gemm_non_contiguous: Test if GEMM works well with non-contiguous matrices.
2241        aval = np.ones((6, 2))
2242        bval = np.ones((2, 7))
2243        cval = np.arange(7) + np.arange(0, .6, .1)[:, np.newaxis]
2244
2245        a = theano.shared(aval[:3], borrow=True)
2246        b = theano.shared(bval[:, :5], borrow=True)
2247        c = theano.shared(cval[:3, :5], borrow=True)
2248
2249        s = theano.tensor.scalar()
2250        upd_c = s * c + theano.tensor.dot(a, b)
2251        f = theano.function([s], [], updates={c: upd_c})
2252
2253        f(0)
2254        ref_output = np.ones((3, 5)) * 2
2255        unittest_tools.assert_allclose(c.get_value(), ref_output)
2256
2257
2258class test_infer_shape(unittest_tools.InferShapeTester):
2259    def test_dot22(self):
2260        x, y = T.matrices('xy')
2261        self._compile_and_check(
2262            [x, y], [T.blas._dot22(x, y)],
2263            [np.random.random((2, 3)).astype(config.floatX),
2264             np.random.random((3, 4)).astype(config.floatX)],
2265            T.blas.Dot22)
2266
2267    def test_dot22scalar(self):
2268        x, y = T.matrices('xy')
2269        a = T.scalar('a')
2270        self._compile_and_check(
2271            [x, y, a], [T.blas._dot22scalar(x, y, a)],
2272            [np.random.random((2, 3)).astype(config.floatX),
2273             np.random.random((3, 4)).astype(config.floatX),
2274             np.asarray(0.5, dtype=config.floatX)],
2275            T.blas.Dot22Scalar)
2276
2277    def test_gemm(self):
2278        x, y, z = T.matrices('xyz')
2279        a = T.scalar('a')
2280        b = T.scalar('b')
2281        self._compile_and_check(
2282            [x, y, a, z, b], [T.blas.gemm(z, a, x, y, b)],
2283            [np.random.random((2, 3)).astype(config.floatX),
2284             np.random.random((3, 4)).astype(config.floatX),
2285             np.asarray(0.5, dtype=config.floatX),
2286             np.random.random((2, 4)).astype(config.floatX),
2287             np.asarray(0.5, dtype=config.floatX)],
2288            T.blas.Gemm)
2289
2290    def test_gemv(self):
2291        A = T.matrix('A')
2292        x, y = T.vectors('xy')
2293        a = T.scalar('a')
2294        b = T.scalar('b')
2295        self._compile_and_check(
2296            [y, a, A, x, b], [T.blas.gemv(y, a, A, x, b)],
2297            [np.random.random((2,)).astype(config.floatX),
2298             np.asarray(0.5, dtype=config.floatX),
2299             np.random.random((2, 3)).astype(config.floatX),
2300             np.random.random((3,)).astype(config.floatX),
2301             np.asarray(0.5, dtype=config.floatX)],
2302            T.blas.Gemv)
2303
2304    def test_ger(self):
2305        A = T.matrix('A')
2306        x, y = T.vectors('xy')
2307        a = T.scalar('a')
2308        self._compile_and_check(
2309            [A, a, x, y], [T.blas.ger(A, a, x, y)],
2310            [np.random.random((2, 3)).astype(config.floatX),
2311             np.asarray(0.5, dtype=config.floatX),
2312             np.random.random((2,)).astype(config.floatX),
2313             np.random.random((3,)).astype(config.floatX)],
2314            T.blas.Ger)
2315