from __future__ import absolute_import, print_function, division
import numpy as np

import theano
from theano import tensor
import theano.tests.unittest_tools as utt
from theano.tensor.nnet.tests import test_blocksparse

from .config import mode_with_gpu, test_ctx_name

from ..type import gpuarray_shared_constructor
from ..blocksparse import (GpuSparseBlockGemv,
                           GpuSparseBlockOuter,
                           gpu_sparse_block_gemv,
                           gpu_sparse_block_outer)


class BlockSparse_Gemv_and_Outer(test_blocksparse.BlockSparse_Gemv_and_Outer):
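    """Run the blocksparse Gemv/Outer test suite against the GPU ops.

    The test methods themselves are inherited from the CPU test case in
    theano.tensor.nnet.tests.test_blocksparse; setUp only swaps in the
    GPU mode, the GPU ops and their op classes.
    """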
    def setUp(self):
        utt.seed_rng()
        self.mode = mode_with_gpu.excluding('constant_folding')
        self.gemv_op = gpu_sparse_block_gemv
        self.outer_op = gpu_sparse_block_outer
        self.gemv_class = GpuSparseBlockGemv
        self.outer_class = GpuSparseBlockOuter

    # This test is temporarily disabled because the output_merge and
    # alpha_merge optimizations for blocksparse were disabled due to
    # brokenness.  Re-enable it once those optimizations are added back.
    def Xtest_blocksparse_grad_merge(self):
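        # Symbolic inputs: output biases (b), input activations (h) and the
        # input/output block index matrices (iIdx, oIdx).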
        b = tensor.fmatrix()
        h = tensor.ftensor3()
        iIdx = tensor.lmatrix()
        oIdx = tensor.lmatrix()

        W_val, h_val, iIdx_val, b_val, oIdx_val = self.gemv_data()
        W = gpuarray_shared_constructor(W_val, context=test_ctx_name)

        o = gpu_sparse_block_gemv(b.take(oIdx, axis=0), W, h, iIdx, oIdx)
        gW = theano.grad(o.sum(), W)

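        # A plain SGD-style update; the alpha/output merge optimizations
        # should fold the whole W - lr * gW expression into the outer op.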
        lr = np.asarray(0.05, dtype='float32')

        upd = W - lr * gW

        f1 = theano.function([h, iIdx, b, oIdx], updates=[(W, upd)],
                             mode=mode_with_gpu)

        # Make sure the lr update was merged.
        assert isinstance(f1.maker.fgraph.outputs[0].owner.op,
                          GpuSparseBlockOuter)

        # Exclude the merge optimizations.
        mode = mode_with_gpu.excluding('local_merge_blocksparse_alpha')
        mode = mode.excluding('local_merge_blocksparse_output')

        f2 = theano.function([h, iIdx, b, oIdx], updates=[(W, upd)], mode=mode)

        # Make sure the lr update is not merged.
        assert not isinstance(f2.maker.fgraph.outputs[0].owner.op,
                              GpuSparseBlockOuter)

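        # Run without the merge optimizations to get the reference weights.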
        f2(h_val, iIdx_val, b_val, oIdx_val)
        W_ref = W.get_value()

        # Reset the shared weights and run the optimized function.
        W.set_value(W_val)
        f1(h_val, iIdx_val, b_val, oIdx_val)
        W_opt = W.get_value()

        # Both versions must compute the same updated weights.
        utt.assert_allclose(W_ref, W_opt)