1import numpy
2
3from chainer.backends import cuda
4from chainer.functions.connection import bilinear
5from chainer import initializers
6from chainer import link
7from chainer import variable
8
9
class Bilinear(link.Link):

    """Link wrapping :func:`~chainer.functions.bilinear`.

    The link owns the parameters that are handed to
    :func:`~chainer.functions.bilinear`: the bilinear weight ``W`` and,
    unless ``nobias`` is set, the linear weights ``V1``, ``V2`` and the
    bias ``b``.

    Args:
        left_size (int): Dimension of input vector :math:`e^1` (:math:`J`)
        right_size (int): Dimension of input vector :math:`e^2` (:math:`K`)
        out_size (int): Dimension of output vector :math:`y` (:math:`L`)
        nobias (bool): If ``True``, the parameters ``V1``, ``V2`` and ``b``
            are not created and ``initial_bias`` is ignored.
        initialW (:ref:`initializer <initializer>`): Initializer of the
            weight ``W``. When an array is given, its shape must be
            ``(left_size, right_size, out_size)``.
        initial_bias (tuple of :ref:`initializer <initializer>`):
            Initializers of :math:`V^1`, :math:`V^2` and :math:`b`, in
            this order. It must be a tuple of exactly three elements or
            ``None``. Elements given as arrays must have the shapes
            ``(left_size, out_size)``, ``(right_size, out_size)`` and
            ``(out_size,)``, respectively.
            If ``None``, :math:`V^1` and :math:`V^2` use the default
            initializer and :math:`b` is set to :math:`0`.

    .. seealso:: See :func:`chainer.functions.bilinear` for details.

    Attributes:
        W (~chainer.Variable): Bilinear weight parameter.
        V1 (~chainer.Variable): Linear weight parameter for the first
            argument. Only present when ``nobias`` is false.
        V2 (~chainer.Variable): Linear weight parameter for the second
            argument. Only present when ``nobias`` is false.
        b (~chainer.Variable): Bias parameter. Only present when ``nobias``
            is false.

    """

    def __init__(self, left_size, right_size, out_size, nobias=False,
                 initialW=None, initial_bias=None):
        super(Bilinear, self).__init__()
        self.in_sizes = (left_size, right_size)
        self.nobias = nobias

        array_types = (numpy.ndarray, cuda.ndarray)

        def check_shape(candidate, expected_shape):
            # Only concrete arrays carry a shape that can be verified;
            # any other kind of initializer is passed through unchecked.
            if isinstance(candidate, array_types):
                assert candidate.shape == expected_shape

        # TODO(Kenta OONO): I do not know appropriate way of
        # initializing weights in tensor network.
        # This initialization is a modification of
        # that of Linear function.

        with self.init_scope():
            w_shape = (left_size, right_size, out_size)
            check_shape(initialW, w_shape)
            self.W = variable.Parameter(
                initializers._get_initializer(initialW), w_shape)

            if self.nobias:
                # Without bias terms W is the only parameter of the link.
                return

            # Shapes of V1, V2 and b, in that order.
            bias_shapes = (
                (left_size, out_size),
                (right_size, out_size),
                (out_size,),
            )

            if initial_bias is None:
                # Default initializer for each linear weight, zero bias.
                bias_inits = (
                    initializers._get_initializer(None),
                    initializers._get_initializer(None),
                    0,
                )
            elif isinstance(initial_bias, tuple):
                # A tuple of any other length fails right here.
                given_v1, given_v2, given_b = initial_bias
                given = (given_v1, given_v2, given_b)
                # Validate all three shapes before any initializer is
                # built.
                for candidate, expected_shape in zip(given, bias_shapes):
                    check_shape(candidate, expected_shape)
                bias_inits = tuple(
                    initializers._get_initializer(candidate)
                    for candidate in given)
            else:
                raise ValueError('initial_bias must be tuple or None')

            self.V1 = variable.Parameter(bias_inits[0], bias_shapes[0])
            self.V2 = variable.Parameter(bias_inits[1], bias_shapes[1])
            self.b = variable.Parameter(bias_inits[2], bias_shapes[2])

    def forward(self, e1, e2):
        """Computes the bilinear function of two inputs.

        Args:
            e1 (~chainer.Variable): Left input.
            e2 (~chainer.Variable): Right input.

        Returns:
            ~chainer.Variable: Output variable.

        """
        # The linear and bias terms take part only when they were created.
        if self.nobias:
            params = (self.W,)
        else:
            params = (self.W, self.V1, self.V2, self.b)
        return bilinear.bilinear(e1, e2, *params)

    def zero_grads(self):
        """Calls ``zerograds``; kept for backward compatibility."""
        self.zerograds()
111