1from __future__ import absolute_import, print_function, division 2import numpy as np 3 4import theano 5from theano import tensor 6import theano.tests.unittest_tools as utt 7from theano.tensor.nnet.tests import test_blocksparse 8 9from .config import mode_with_gpu, test_ctx_name 10 11from ..type import gpuarray_shared_constructor 12from ..blocksparse import (GpuSparseBlockGemv, 13 GpuSparseBlockOuter, 14 gpu_sparse_block_gemv, 15 gpu_sparse_block_outer) 16 17 18class BlockSparse_Gemv_and_Outer(test_blocksparse.BlockSparse_Gemv_and_Outer): 19 def setUp(self): 20 utt.seed_rng() 21 self.mode = mode_with_gpu.excluding('constant_folding') 22 self.gemv_op = gpu_sparse_block_gemv 23 self.outer_op = gpu_sparse_block_outer 24 self.gemv_class = GpuSparseBlockGemv 25 self.outer_class = GpuSparseBlockOuter 26 27 # This test is temporarily disabled since we disabled the output_merge 28 # and alpha_merge optimizations for blocksparse due to brokeness. 29 # Re-enable when those are re-added. 30 def Xtest_blocksparse_grad_merge(self): 31 b = tensor.fmatrix() 32 h = tensor.ftensor3() 33 iIdx = tensor.lmatrix() 34 oIdx = tensor.lmatrix() 35 36 W_val, h_val, iIdx_val, b_val, oIdx_val = self.gemv_data() 37 W = gpuarray_shared_constructor(W_val, context=test_ctx_name) 38 39 o = gpu_sparse_block_gemv(b.take(oIdx, axis=0), W, h, iIdx, oIdx) 40 gW = theano.grad(o.sum(), W) 41 42 lr = np.asarray(0.05, dtype='float32') 43 44 upd = W - lr * gW 45 46 f1 = theano.function([h, iIdx, b, oIdx], updates=[(W, upd)], 47 mode=mode_with_gpu) 48 49 # Make sure the lr update was merged. 50 assert isinstance(f1.maker.fgraph.outputs[0].owner.op, 51 GpuSparseBlockOuter) 52 53 # Exclude the merge optimizations. 54 mode = mode_with_gpu.excluding('local_merge_blocksparse_alpha') 55 mode = mode.excluding('local_merge_blocksparse_output') 56 57 f2 = theano.function([h, iIdx, b, oIdx], updates=[(W, upd)], mode=mode) 58 59 # Make sure the lr update is not merged. 60 assert not isinstance(f2.maker.fgraph.outputs[0].owner.op, 61 GpuSparseBlockOuter) 62 63 f2(h_val, iIdx_val, b_val, oIdx_val) 64 W_ref = W.get_value() 65 66 # reset the var 67 W.set_value(W_val) 68 f1(h_val, iIdx_val, b_val, oIdx_val) 69 W_opt = W.get_value() 70 71 utt.assert_allclose(W_ref, W_opt) 72