1import chainer
2from chainer.functions.array import broadcast
3from chainer.functions.array import reshape
4
5
def scale(x, y, axis=1):
    """Elementwise product with broadcasting.

    Computes an elementwise product of two input variables, with the shape of
    the latter variable broadcasted to match the shape of the former. ``axis``
    is the first axis of the first variable along which the second variable is
    applied.

    The term "broadcasting" here comes from Caffe's scale layer so the
    "broadcasting" with the following arguments::

           x : 100 x 3 x 40 x 5 x 6
           y : 3 x 40
        axis : 1

    is equivalent to the following numpy broadcasting::

        x : 100 x  3 x 40 x 5 x 6
        y :  (1 x) 3 x 40 x 1 x 1

    Note that the axis of ``x`` to which we apply ``y`` is specified by the
    argument ``axis``, whose meaning is different from numpy's ``axis``.
    A negative ``axis`` counts from the end of ``x``'s shape, as in Python
    indexing; e.g. with a 5-dimensional ``x``, ``axis=-4`` is ``axis=1``.

    Args:
        x (:class:`~chainer.Variable` or :ref:`ndarray`):
            Input variable to be scaled.
        y (:class:`~chainer.Variable` or :ref:`ndarray`):
            Input variable to scale, broadcasted.
        axis (int): The first axis of ``x`` along which ``y`` is applied.
            Negative values are counted from the last axis of ``x``.

    Returns:
        ~chainer.Variable: Output variable.

    """
    x_shape = x.shape
    y_shape = tuple(y.shape)
    x_ndim = len(x_shape)
    y_ndim = len(y_shape)

    if axis < 0:
        # Normalize a Python-style negative index. Without this, the padding
        # arithmetic below yields a target shape of rank greater than x.ndim,
        # which can never be broadcast to x's shape.
        axis += x_ndim

    if chainer.is_debug():
        assert 0 <= axis <= x_ndim - y_ndim, (
            'axis {} is out of range for x of shape {} and y of shape {}'
            .format(axis, x_shape, y_shape))
        assert x_shape[axis:axis + y_ndim] == y_shape

    # Surround y's shape with singleton axes so that its dimensions line up
    # with x's dimensions starting at ``axis``; e.g. y (3, 40) with axis=1
    # against a 5-D x becomes (1, 3, 40, 1, 1).
    leading = (1,) * axis
    trailing = (1,) * (x_ndim - axis - y_ndim)
    y1 = reshape.reshape(y, leading + y_shape + trailing)
    y2 = broadcast.broadcast_to(y1, x_shape)
    return x * y2
49