1import numpy
2
3import chainer
4from chainer.functions.activation import maxout
5from chainer import initializer
6from chainer import link
7from chainer.links.connection import linear
8
9
class Maxout(link.Chain):
    """Fully-connected maxout layer.

    Let ``M``, ``P`` and ``N`` be an input dimension, a pool size,
    and an output dimension, respectively.
    For an input vector :math:`x` of size ``M``, it computes

    .. math::

      Y_{i} = \\mathrm{max}_{j} (W_{ij\\cdot}x + b_{ij}).

    Here :math:`W` is a weight tensor of shape ``(N, P, M)``,
    :math:`b` an optional bias vector of shape ``(N, P)``
    and :math:`W_{ij\\cdot}` is a sub-vector extracted from
    :math:`W` by fixing first and second dimensions to
    :math:`i` and :math:`j`, respectively.
    Minibatch dimension is omitted in the above equation.

    As for the actual implementation, this chain has a
    Linear link with a ``(N * P, M)`` weight matrix and
    an optional ``N * P`` dimensional bias vector; the maxout
    nonlinearity then takes the maximum over each group of
    ``P`` consecutive output units.

    Args:
        in_size (int): Dimension of input vectors.
        out_size (int): Dimension of output vectors.
        pool_size (int): Number of linear units pooled (via max)
            per output unit.
        initialW (:ref:`initializer <initializer>`): Initializer to
            initialize the weight. When it is :class:`numpy.ndarray`,
            its ``ndim`` should be 3 and its shape should be
            ``(out_size, pool_size, in_size)``.
        initial_bias (:ref:`initializer <initializer>`): Initializer to
            initialize the bias. If ``None``, the bias is omitted.
            When it is :class:`numpy.ndarray`, its ``ndim`` should be 2
            and its shape should be ``(out_size, pool_size)``.
            The default value ``0`` fills the bias with zeros.

    Attributes:
        linear (~chainer.Link): The Linear link that performs
            affine transformation.

    .. seealso:: :func:`~chainer.functions.maxout`

    .. seealso::
         Goodfellow, I., Warde-farley, D., Mirza, M.,
         Courville, A., & Bengio, Y. (2013).
         Maxout Networks. In Proceedings of the 30th International
         Conference on Machine Learning (ICML-13) (pp. 1319-1327).
         `URL <http://jmlr.org/proceedings/papers/v28/goodfellow13.html>`_
    """

    def __init__(self, in_size, out_size, pool_size,
                 initialW=None, initial_bias=0):
        super(Maxout, self).__init__()

        # The underlying Linear link stores the weight flattened to
        # ``(out_size * pool_size, in_size)``; the maxout pooling in
        # ``forward`` undoes this flattening.
        linear_out_size = out_size * pool_size

        if initialW is None or \
           numpy.isscalar(initialW) or \
           isinstance(initialW, initializer.Initializer):
            # These forms are understood by Linear as-is.
            pass
        elif isinstance(initialW, chainer.get_array_types()):
            # A concrete 3-d array ``(out_size, pool_size, in_size)``
            # is flattened to the 2-d layout Linear expects.
            if initialW.ndim != 3:
                raise ValueError('initialW.ndim should be 3')
            initialW = initialW.reshape(linear_out_size, in_size)
        elif callable(initialW):
            initialW_orig = initialW

            # Initializer callables fill ``array`` in place, so we
            # temporarily view the flat 2-d weight as a 3-d tensor, let
            # the user-supplied initializer fill it, and restore the
            # 2-d shape.  Assigning to ``array.shape`` reshapes without
            # copying, so the writes land in Linear's actual weight.
            def initialW(array):
                array.shape = (out_size, pool_size, in_size)
                initialW_orig(array)
                array.shape = (linear_out_size, in_size)

        if initial_bias is None or \
           numpy.isscalar(initial_bias) or \
           isinstance(initial_bias, initializer.Initializer):
            # These forms are understood by Linear as-is (``None``
            # additionally disables the bias below).
            pass
        elif isinstance(initial_bias, chainer.get_array_types()):
            # A concrete 2-d array ``(out_size, pool_size)`` is
            # flattened to the 1-d bias Linear expects.
            if initial_bias.ndim != 2:
                raise ValueError('initial_bias.ndim should be 2')
            initial_bias = initial_bias.reshape(linear_out_size)
        elif callable(initial_bias):
            initial_bias_orig = initial_bias

            # Same in-place reshape trick as for the weight, but for
            # the 1-d bias viewed as ``(out_size, pool_size)``.
            def initial_bias(array):
                array.shape = (out_size, pool_size)
                initial_bias_orig(array)
                array.shape = linear_out_size,

        with self.init_scope():
            # The bias is omitted only when the caller explicitly
            # passes ``initial_bias=None``.
            self.linear = linear.Linear(
                in_size, linear_out_size,
                nobias=initial_bias is None, initialW=initialW,
                initial_bias=initial_bias)

        # Remembered so ``forward`` can pool over each group of
        # ``pool_size`` linear outputs.
        self.out_size = out_size
        self.pool_size = pool_size

    def forward(self, x):
        """Applies the maxout layer.

        Args:
            x (~chainer.Variable): Batch of input vectors.

        Returns:
            ~chainer.Variable: Output of the maxout layer.
        """
        # Affine transform to ``out_size * pool_size`` units, then take
        # the maximum over each group of ``pool_size`` units.
        y = self.linear(x)
        return maxout.maxout(y, self.pool_size)
115