import chainer
from chainer.functions.array import broadcast
from chainer.functions.array import reshape


def scale(x, y, axis=1):
    """Elementwise product with broadcasting.

    Computes an elementwise product of two input variables, with the shape of
    the latter variable broadcast to match the shape of the former. ``axis``
    is the first axis of the first variable along which the second variable is
    applied.

    The term "broadcasting" here follows Caffe's scale layer, so the
    "broadcasting" with the following arguments::

           x : 100 x 3 x 40 x 5 x 6
           y : 3 x 40
        axis : 1

    is equivalent to the following numpy broadcasting::

        x : 100 x 3 x 40 x 5 x 6
        y : (1 x) 3 x 40 x 1 x 1

    Note that the axis of ``x`` to which we apply ``y`` is specified by the
    argument ``axis``, whose meaning is different from numpy's ``axis``.

    Args:
        x (:class:`~chainer.Variable` or :ref:`ndarray`):
            Input variable to be scaled.
        y (:class:`~chainer.Variable` or :ref:`ndarray`):
            Input variable by which ``x`` is scaled; it is broadcast to the
            shape of ``x``.
        axis (int): The first axis of ``x`` along which ``y`` is applied.

    Returns:
        ~chainer.Variable: Output variable.

    """
    x_shape = x.shape
    y_shape = y.shape
    if chainer.is_debug():
        # The shape of y must match the slice of x's shape starting at axis.
        assert x_shape[axis:axis + len(y_shape)] == y_shape
    # Pad y's shape with singleton axes before and after, e.g. (3, 40) ->
    # (1, 3, 40, 1, 1) for a 5-D x and axis=1.
    y1_shape = tuple([1] * axis + list(y_shape) +
                     [1] * (len(x_shape) - axis - len(y_shape)))
    y1 = reshape.reshape(y, y1_shape)
    # Explicitly broadcast to x's shape, then multiply elementwise.
    y2 = broadcast.broadcast_to(y1, x_shape)
    return x * y2
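

# Illustrative sketch (not part of Chainer's API): a plain NumPy reference for
# the Caffe-style broadcasting that ``scale`` performs above. The helper name
# ``_scale_reference_numpy`` is hypothetical, and it assumes NumPy is
# installed. It mirrors the same two steps: pad the shape of ``y`` with
# singleton axes so that it lines up with ``x.shape[axis:axis + y.ndim]``,
# then let NumPy broadcasting carry out the elementwise product.
#
# Example usage, with the shapes from the docstring of ``scale``:
#     x = numpy.ones((100, 3, 40, 5, 6), dtype=numpy.float32)
#     y = numpy.ones((3, 40), dtype=numpy.float32)
#     _scale_reference_numpy(x, y, axis=1).shape  # (100, 3, 40, 5, 6)
def _scale_reference_numpy(x, y, axis=1):
    import numpy

    x = numpy.asarray(x)
    y = numpy.asarray(y)
    # Same shape condition that ``scale`` asserts in debug mode.
    if x.shape[axis:axis + y.ndim] != y.shape:
        raise ValueError(
            'y.shape {} does not match x.shape[{}:{}] = {}'.format(
                y.shape, axis, axis + y.ndim, x.shape[axis:axis + y.ndim]))
    # Leading and trailing singleton axes, e.g. (3, 40) -> (1, 3, 40, 1, 1).
    y1_shape = (1,) * axis + y.shape + (1,) * (x.ndim - axis - y.ndim)
    # NumPy broadcasting expands the singleton axes to match x.
    return x * y.reshape(y1_shape)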