1from __future__ import absolute_import, print_function, division
2import time
3
4from nose.plugins.skip import SkipTest
5import numpy as np
6import theano
7import theano.tensor as T
8from theano.tests import unittest_tools as utt
9from theano.tensor.nnet import conv
10from theano.tensor.basic import _allclose, NotScalarConstantError
11from theano.tests.unittest_tools import attr
12
13
class TestConv2D(utt.InferShapeTester):
    # Tests for the legacy 2d convolution (``theano.tensor.nnet.conv``).
    # Other implementations reuse these tests by subclassing and replacing
    # ``conv2d`` (and optionally ``mode``); every graph built here goes
    # through ``self.conv2d`` so such overrides are actually exercised.
    mode = None
    dtype = theano.config.floatX
    # Convolution entry point under test; inherited classes override it.
    # The call to `staticmethod` is necessary to prevent Python from passing
    # `self` as the first argument.
    conv2d = staticmethod(conv.conv2d)

    def setUp(self):
        super(TestConv2D, self).setUp()
        self.input = T.tensor4('input', dtype=self.dtype)
        self.input.name = 'default_V'
        self.filters = T.tensor4('filters', dtype=self.dtype)
        self.filters.name = 'default_filters'
        if not conv.imported_scipy_signal and theano.config.cxx == "":
            raise SkipTest("conv2d tests need SciPy or a c++ compiler")

    def validate(self, image_shape, filter_shape,
                 border_mode='valid', subsample=(1, 1),
                 N_image_shape=None, N_filter_shape=None,
                 input=None, filters=None,
                 unroll_batch=None, unroll_kern=None, unroll_patch=None,
                 verify_grad=True, should_raise=False):
        """
        Compile ``self.conv2d`` and check it against a NumPy reference.

        :param image_shape: The constant shape info passed to conv2d
                            (None, or a 4-tuple possibly containing None).
        :param filter_shape: The constant shape info passed to conv2d
                             (None, or a 4-tuple possibly containing None).
        :param border_mode: Passed to conv2d. The reference treats any value
                            other than 'full' as 'valid'.
        :param subsample: (row, col) output strides passed to conv2d.
        :param N_image_shape: None (default to image_shape) or tuple of
                              4 elements with the runtime shape of the
                              input image.
        :param N_filter_shape: None (default to filter_shape) or tuple
                               of 4 elements with the runtime shape of the
                               input filter.
        :param input: Symbolic image; defaults to ``self.input``.
        :param filters: Symbolic filters; defaults to ``self.filters``.
        :param unroll_batch: Forwarded to conv2d.
        :param unroll_kern: Forwarded to conv2d.
        :param unroll_patch: Forwarded to conv2d.
        :param verify_grad: If True, also check gradients with
                            ``utt.verify_grad``.
        :param should_raise: If True, calling the compiled function must
                             raise ValueError; the test fails otherwise.
        """
        if N_image_shape is None:
            N_image_shape = [T.get_scalar_constant_value(
                T.as_tensor_variable(x)) for x in image_shape]
        if N_filter_shape is None:
            N_filter_shape = [T.get_scalar_constant_value(
                T.as_tensor_variable(x)) for x in filter_shape]

        if input is None:
            input = self.input
        if filters is None:
            filters = self.filters

        # THEANO IMPLEMENTATION

        # A function (rather than a single graph) so that verify_grad can
        # rebuild the expression on its own inputs.
        def sym_conv2d(input, filters):
            input.name = 'input'
            filters.name = 'filters'
            rval = self.conv2d(
                input, filters, image_shape, filter_shape,
                border_mode=border_mode, subsample=subsample,
                unroll_batch=unroll_batch, unroll_kern=unroll_kern,
                unroll_patch=unroll_patch)
            rval.name = 'conv_output'
            return rval

        output = sym_conv2d(input, filters)
        output.name = 'conv2d(%s,%s)' % (input.name, filters.name)
        theano_conv = theano.function([input, filters], output, mode=self.mode)

        # initialize input and compute result
        image_data = np.random.random(N_image_shape).astype(self.dtype)
        filter_data = np.random.random(N_filter_shape).astype(self.dtype)
        try:
            theano_output = theano_conv(image_data, filter_data)
        except ValueError:
            if not should_raise:
                raise
            return
        if should_raise:
            self.fail("ConvOp should have generated an error")

        # REFERENCE IMPLEMENTATION
        ref_output = self._reference_conv2d(
            image_data, filter_data, border_mode, subsample)
        self.assertTrue(_allclose(theano_output, ref_output))

        # TEST GRADIENT
        if verify_grad:
            utt.verify_grad(sym_conv2d, [image_data, filter_data])

    @staticmethod
    def _reference_conv2d(image_data, filter_data, border_mode, subsample):
        """
        Naive NumPy 2d convolution used as ground truth by ``validate``.

        The kernel is flipped (convolution, not cross-correlation). In
        'full' mode the image is zero-padded by (kernel size - 1) on every
        side; any other ``border_mode`` is computed as 'valid'. Output
        rows/cols are ceil(size / stride). Returns a float64 array of shape
        (batch, n_kernels, out_rows, out_cols).
        """
        n_batch, n_channels, img_h, img_w = image_data.shape
        n_kern, _, kern_h, kern_w = filter_data.shape

        if border_mode == 'full':
            out_h, out_w = img_h + kern_h - 1, img_w + kern_w - 1
        else:
            out_h, out_w = img_h - kern_h + 1, img_w - kern_w + 1
        # Integer ceiling division: a strided output keeps a partial step.
        out_h = -(-out_h // subsample[0])
        out_w = -(-out_w // subsample[1])
        ref_output = np.zeros((n_batch, n_kern, out_h, out_w))

        if border_mode == 'full':
            pad_h, pad_w = kern_h - 1, kern_w - 1
            padded = np.zeros((n_batch, n_channels,
                               img_h + 2 * pad_h, img_w + 2 * pad_w))
            padded[:, :, pad_h:pad_h + img_h, pad_w:pad_w + img_w] = image_data
            image_data = padded

        for bb in range(n_batch):
            for nn in range(n_kern):
                for im0 in range(n_channels):
                    image2d = image_data[bb, im0]
                    # Flip once per (kernel, channel) pair.
                    flipped = filter_data[nn, im0, ::-1, ::-1]
                    for row in range(out_h):
                        irow = row * subsample[0]  # image row
                        for col in range(out_w):
                            icol = col * subsample[1]  # image col
                            patch = image2d[irow:irow + kern_h,
                                            icol:icol + kern_w]
                            ref_output[bb, nn, row, col] += (
                                patch * flipped).sum()
        return ref_output

    def test_basic1(self):
        # Tests that basic convolutions work for odd and even
        # dimensions of image and filter shapes, as well as rectangular
        # images and filters.

        self.validate((2, 2, 3, 3), (2, 2, 2, 2), 'valid', verify_grad=False)

    def test_basic(self):
        # Tests that basic convolutions work for odd and even
        # dimensions of image and filter shapes, as well as rectangular
        # images and filters.

        self.validate((3, 2, 8, 8), (4, 2, 5, 5), 'valid', verify_grad=False)
        self.validate((3, 2, 7, 5), (5, 2, 2, 3), 'valid')
        self.validate((3, 2, 7, 5), (5, 2, 3, 2), 'valid', verify_grad=False)
        self.validate((3, 2, 8, 8), (4, 2, 5, 5), 'full', verify_grad=False)
        self.validate((3, 2, 7, 5), (5, 2, 2, 3), 'full')

    def test_uint_image_shape_datatype(self):
        # Tests for uint datatype in image_shape.

        self.validate((2, 2, 3, np.uint8(3)), (3, 2, 3, 3), 'valid', verify_grad=False)
        self.validate((np.uint16(2), 2, 3, 3), (3, 2, 3, 3), 'valid', verify_grad=False)
        self.validate((2, np.uint32(2), 3, 3), (3, 2, 3, 3), 'valid', verify_grad=False)

    def test_uint_filter_shape_datatype(self):
        # Tests for uint datatype in filter_shape

        self.validate((3, 2, 3, 3), (2, 2, 3, np.uint8(3)), 'valid', verify_grad=False)
        self.validate((3, 2, 3, 3), (np.uint16(2), 2, 3, 3), 'valid', verify_grad=False)
        self.validate((3, 2, 3, 3), (2, np.uint32(2), 3, 3), 'valid', verify_grad=False)

    def test_img_kernel_same_shape(self):
        # Tests a filter with the same spatial size as the input image.
        self.validate((3, 2, 3, 3), (4, 2, 3, 3), 'full')
        self.validate((3, 2, 3, 3), (4, 2, 3, 3), 'valid')

    def test_unroll_patch_true(self):
        # Test basic convs with unroll_patch=True.

        self.validate((3, 2, 7, 5), (5, 2, 2, 3), 'valid', unroll_patch=True)
        self.validate((3, 2, 7, 5), (5, 2, 2, 3), 'full', unroll_patch=True)
        self.validate(
            (3, 2, 3, 3), (4, 2, 3, 3), 'valid',
            unroll_patch=True, verify_grad=False)

    def test_unroll_patch_false(self):
        # Test basic convs with unroll_patch=False.

        self.validate((3, 2, 7, 5), (5, 2, 2, 3), 'valid', unroll_patch=False)
        self.validate((3, 2, 7, 5), (5, 2, 2, 3), 'full', unroll_patch=False)
        self.validate(
            (3, 2, 3, 3), (4, 2, 3, 3), 'valid',
            unroll_patch=False, verify_grad=False)

    def test_unroll_patch_true_fail(self):
        # With unroll_patch=True, runtime shapes that disagree with the
        # declared shapes must raise an error.

        self.validate(
            (3, 2, 7, 5), (5, 2, 2, 3), 'valid', unroll_patch=True,
            N_image_shape=(1, 3, 3, 3), N_filter_shape=(6, 3, 2, 2),
            should_raise=True)
        self.validate(
            (3, 2, 7, 5), (5, 2, 2, 3), 'full', unroll_patch=True,
            N_image_shape=(1, 3, 3, 3), N_filter_shape=(6, 3, 2, 2),
            should_raise=True)
        self.validate(
            (3, 2, 3, 3), (4, 2, 3, 3), 'valid', unroll_patch=True,
            N_image_shape=(1, 3, 3, 3), N_filter_shape=(6, 3, 2, 2),
            should_raise=True)

    def test_unroll_special(self):
        # (unroll_kern, unroll_batch) in (0,1),(1,0) is special case.

        self.validate((6, 2, 3, 3), (3, 2, 2, 2), 'valid', unroll_batch=1)

    def test_unroll_batch(self):
        # Test mini-batch unrolling for various legal values.

        # mini-batch of size 6 is multiple of 2 and 3. Should work.
        self.validate(
            (6, 2, 3, 3), (3, 2, 2, 2), 'valid',
            unroll_batch=2, verify_grad=False)
        self.validate(
            (6, 2, 3, 3), (3, 2, 2, 2), 'valid',
            unroll_batch=3, verify_grad=False)

    def test_unroll_kern(self):
        # Test kernel unrolling for various legal values.

        # 6 filters is a multiple of 2 and 3. Should work.
        self.validate(
            (2, 3, 3, 3), (6, 3, 2, 2), 'valid', unroll_kern=2,
            verify_grad=False)
        self.validate(
            (2, 3, 3, 3), (6, 3, 2, 2), 'valid', unroll_kern=3,
            verify_grad=False)

    def test_unroll_batch_kern(self):
        # Test mini-batch unrolling with kernel unrolling for various
        # legal values.

        # mini-batch of size 6 is multiple of 2 and 3. Should work.
        self.validate(
            (6, 2, 3, 3), (3, 2, 2, 2), 'valid',
            unroll_batch=2, unroll_kern=3, verify_grad=False)
        self.validate(
            (6, 2, 3, 3), (3, 2, 2, 2), 'valid',
            unroll_batch=3, unroll_kern=3, verify_grad=False)
        # 6 filters is a multiple of 2 and 3. Should work.
        self.validate(
            (2, 3, 3, 3), (6, 3, 2, 2), 'valid',
            unroll_batch=2, unroll_kern=2, verify_grad=False)
        self.validate(
            (2, 3, 3, 3), (6, 3, 2, 2), 'valid',
            unroll_batch=2, unroll_kern=3, verify_grad=False)

    def test_unroll_batch_kern_fail(self):
        # Test mini-batch unrolling with kernel unrolling for various
        # legal values, but pass bad input.  All those test must
        # generate errors

        # The declared shapes are legal for the unroll factors, but the
        # runtime batch size / number of kernels differ from the declared
        # ones, so every call must fail.
        self.validate((6, 2, 3, 3), (3, 2, 2, 2), 'valid',
                      unroll_batch=2, unroll_kern=3,
                      N_image_shape=(7, 2, 3, 3), N_filter_shape=(3, 2, 2, 2),
                      should_raise=True)
        self.validate((6, 2, 3, 3), (3, 2, 2, 2), 'valid',
                      unroll_batch=3, unroll_kern=3,
                      N_image_shape=(6, 2, 3, 3), N_filter_shape=(4, 2, 2, 2),
                      should_raise=True)
        self.validate(
            (2, 3, 3, 3), (6, 3, 2, 2), 'valid',
            unroll_batch=2, unroll_kern=2,
            N_image_shape=(1, 3, 3, 3), N_filter_shape=(6, 3, 2, 2),
            should_raise=True)
        self.validate(
            (2, 3, 3, 3), (6, 3, 2, 2), 'valid',
            unroll_batch=2, unroll_kern=3,
            N_image_shape=(2, 3, 3, 3), N_filter_shape=(5, 3, 2, 2),
            should_raise=True)

    @attr('slow')
    def test_subsample(self):
        # Tests convolution where subsampling != (1,1)
        self.validate((3, 2, 7, 5), (5, 2, 2, 3), 'full', subsample=(2, 2))

        # Fails as of 2012-07-11
        self.assertRaises(NotImplementedError, self.validate,
                          (1, 1, 6, 6), (1, 1, 3, 3), 'full',
                          subsample=(3, 3))

        # Fails as of 2017-08-10
        valid_cases = [
            ((3, 2, 7, 5), (5, 2, 2, 3), (2, 2)),
            ((3, 2, 7, 5), (5, 2, 2, 3), (2, 1)),
            ((1, 1, 6, 6), (1, 1, 3, 3), (3, 3)),
        ]
        for image_shape, filter_shape, subsample in valid_cases:
            self.assertRaises(NotImplementedError, self.validate,
                              image_shape, filter_shape, 'valid',
                              subsample=subsample)

    def test_shape_Constant_tensor(self):
        # Tests convolution where the {image,filter}_shape is a Constant tensor.

        as_t = T.as_tensor_variable
        self.validate(
            (as_t(3), as_t(2), as_t(7), as_t(5)), (5, 2, 2, 3), 'valid')
        self.validate(as_t([3, 2, 7, 5]), (5, 2, 2, 3), 'valid')
        self.validate(as_t((3, 2, 7, 5)), (5, 2, 2, 3), 'valid')
        self.validate(
            (3, 2, 7, 5), (as_t(5), as_t(2), as_t(2), as_t(3)), 'valid')
        self.validate((3, 2, 7, 5), as_t([5, 2, 2, 3]), 'valid')
        self.validate((3, 2, 7, 5), as_t((5, 2, 2, 3)), 'valid')
        self.validate(as_t([3, 2, 7, 5]), as_t([5, 2, 2, 3]), 'full')

    def test_invalid_filter_shape(self):
        # Tests scenario where filter_shape[1] != input_shape[1]

        self.assertRaises(AssertionError, self.validate,
                          (3, 2, 8, 8), (4, 3, 5, 5),
                          'valid')

    @attr('slow')
    def test_invalid_input_shape(self):
        # Tests that when the shape given at build time is not the same as
        # run time we raise an error

        # Each runtime shape differs from the declared (3, 2, 8, 8) image
        # or (4, 2, 5, 5) filter in exactly one dimension.
        bad_image_shapes = [(2, 2, 8, 8), (3, 1, 8, 8),
                            (3, 2, 7, 8), (3, 2, 8, 7)]
        bad_filter_shapes = [(3, 2, 5, 5), (4, 1, 5, 5),
                             (4, 2, 6, 5), (4, 2, 5, 6)]

        for unroll_batch in [None, 1, 3]:
            for unroll_kern in [None, 2, 4]:
                for unroll_patch in [None, True, False]:
                    for mode in ['valid', 'full']:
                        unroll = dict(unroll_batch=unroll_batch,
                                      unroll_kern=unroll_kern,
                                      unroll_patch=unroll_patch)
                        for shape in bad_image_shapes:
                            self.assertRaises(ValueError, self.validate,
                                              (3, 2, 8, 8), (4, 2, 5, 5),
                                              mode, N_image_shape=shape,
                                              **unroll)
                        for shape in bad_filter_shapes:
                            self.assertRaises(ValueError, self.validate,
                                              (3, 2, 8, 8), (4, 2, 5, 5),
                                              mode, N_filter_shape=shape,
                                              **unroll)

    def test_missing_info(self):
        # Test convolutions for various pieces of missing info.

        self.validate(None, None,
                      N_image_shape=(3, 2, 8, 8),
                      N_filter_shape=(4, 2, 5, 5))
        self.validate((3, 2, None, None), None,
                      N_image_shape=(3, 2, 8, 8),
                      N_filter_shape=(4, 2, 5, 5))
        self.validate((None, 2, None, None), (None, 2, 5, 5),
                      N_image_shape=(3, 2, 8, 8),
                      N_filter_shape=(4, 2, 5, 5))
        self.validate((3, 2, 8, 8), (4, 2, None, 5),
                      N_image_shape=(3, 2, 8, 8),
                      N_filter_shape=(4, 2, 5, 5))
        self.validate((3, 2, 8, 8), (4, 2, 5, None),
                      N_image_shape=(3, 2, 8, 8),
                      N_filter_shape=(4, 2, 5, 5))

    def test_wrong_info(self):
        # Test convolutions when we don't give a constant as shape information

        i = theano.scalar.basic.int32()
        self.assertRaises(NotScalarConstantError, self.validate,
                          (3, 2, 8, i), (4, 2, 5, 5),
                          N_image_shape=(3, 2, 8, 8),
                          N_filter_shape=(4, 2, 5, 5))
        self.assertRaises(NotScalarConstantError, self.validate,
                          (3, 2, 8, 8), (4, 2, 5, i),
                          N_image_shape=(3, 2, 8, 8),
                          N_filter_shape=(4, 2, 5, 5))

    def test_full_mode(self):
        # Tests basic convolution in full mode and case where filter
        # is larger than the input image.

        self.validate((3, 2, 5, 5), (4, 2, 8, 8), 'full')

        # A filter larger than the image is invalid in 'valid' mode.
        self.assertRaises(Exception, self.validate,
                          (3, 2, 5, 5), (4, 2, 8, 8), 'valid')

    def test_wrong_input(self):
        # Make sure errors are raised when image and kernel are not 4D tensors

        self.assertRaises(Exception, self.validate, (3, 2, 8, 8), (4, 2, 5, 5),
                          'valid', input=T.dmatrix())
        self.assertRaises(Exception, self.validate, (3, 2, 8, 8), (4, 2, 5, 5),
                          'valid', filters=T.dvector())
        self.assertRaises(Exception, self.validate, (3, 2, 8, 8), (4, 2, 5, 5),
                          'valid', input=T.dtensor3())

    def test_gcc_crash(self):
        # gcc 4.3.0 20080428 (Red Hat 4.3.0-8)
        #
        # crashed in this following case. I changed the c code to don't hit
        # gcc bug. So it should not crash anymore

        self.validate((1, 10, 213, 129), (46, 10, 212, 1), 'valid',
                      verify_grad=False)

    def speed(self):
        # Manual benchmark (not collected as a test): prints the wall time of
        # n_calls evaluations of self.conv2d for several shapes, with and
        # without OpenMP, using the C VM loop.
        n_calls = 20000
        print("n_calls", n_calls)
        for border_mode in ['valid', 'full']:
            print()
            print(border_mode)
            for openmp in [False, True]:
                print("OpenMP", openmp)
                image_shapes = [
                    (1, 5, 6, 6),
                    (10, 5, 6, 6),
                    # (10, 10, 16, 16),
                    # (10, 10, 32, 32),
                ]
                print("image_shape", image_shapes)
                for image_shape in image_shapes:
                    filter_shapes = [(1, 5, 4, 4), (2, 5, 4, 4), (5, 5, 4, 4)]
                    print("filter_shapes", filter_shapes)
                    for filter_shape in filter_shapes:

                        input = theano.shared(np.random.random(image_shape))
                        filters = theano.shared(np.random.random(filter_shape))

                        output = self.conv2d(
                            input, filters,
                            image_shape, filter_shape,
                            border_mode,
                            unroll_patch=True,
                            openmp=openmp)
                        mode = theano.Mode(linker=theano.gof.vm.VM_Linker(
                            allow_gc=False,
                            use_cloop=True))
                        theano_conv = theano.function([], output, mode=mode)
                        t1 = time.time()
                        theano_conv.fn(n_calls=n_calls)
                        t2 = time.time()
                        print(t2 - t1, end=' ')
                    print()

    def test_infer_shape(self):
        # Note: infer_shape is incomplete and thus input and filter shapes
        # must be provided explicitly

        def rand(*shape):
            # Uniform float64 values in [-1, 1).
            r = np.asarray(np.random.rand(*shape), dtype='float64')
            return r * 2 - 1

        adtens = T.dtensor4()
        bdtens = T.dtensor4()
        # (image shape, filter shape) pairs; each is checked in both modes.
        shape_pairs = [
            ([4, 5, 6, 3], [7, 5, 3, 2]),
            ([6, 2, 8, 3], [4, 2, 5, 3]),
            ([3, 6, 7, 5], [5, 6, 3, 2]),
            ([3, 6, 7, 5], [5, 6, 2, 3]),
            ([5, 2, 4, 3], [6, 2, 4, 3]),
        ]
        for aivec_val, bivec_val in shape_pairs:
            adtens_val = rand(*aivec_val)
            bdtens_val = rand(*bivec_val)
            for border_mode in ('valid', 'full'):
                self._compile_and_check(
                    [adtens, bdtens],
                    [self.conv2d(
                        adtens, bdtens, aivec_val, bivec_val,
                        border_mode=border_mode)],
                    [adtens_val, bdtens_val], conv.ConvOp,
                    excluding=['conv_gemm'])
585
586
587# Test that broadcasting of gradients works correctly when using the
588# nnet.conv2d() interface. This was reported in #3763, and uses the example
589# code from that ticket.
590
591
def test_broadcast_grad():
    # Regression test for #3763: building the gradient of a conv2d output
    # with respect to a scalar that only reaches the graph through a
    # dimshuffled (broadcast) filter must not fail.
    radius = 3
    x = T.tensor4('x')
    sigma = T.scalar('sigma')

    # Normalised 1-D Gaussian taps over [-radius, radius].
    taps = T.arange(-radius, radius + 1).astype(theano.config.floatX)
    gaussian = T.exp(-0.5 * taps ** 2 / sigma ** 2)
    gaussian = gaussian / gaussian.sum()

    # Insert singleton dims so the taps become a single column filter.
    kernel = gaussian.dimshuffle(['x', 'x', 0, 'x'])

    out = theano.tensor.nnet.conv2d(x, kernel, border_mode='full',
                                    filter_shape=[1, 1, None, None])
    theano.grad(out.sum(), sigma)
610
611
if __name__ == '__main__':
    # Ad-hoc entry point: run test_infer_shape without a test runner.
    tester = TestConv2D('setUp')
    tester.setUp()
    tester.test_infer_shape()
617