1from chainer.functions.connection import local_convolution_2d
2from chainer import initializers
3from chainer import link
4from chainer import variable
5
6
7def _pair(x):
8    if hasattr(x, '__getitem__'):
9        return x
10    return x, x
11
12
13def _conv_output_length(input_length, filter_size, stride):
14    output_length = input_length - filter_size + 1
15    return output_length
16
17
class LocalConvolution2D(link.Link):

    """Two-dimensional local convolutional layer.

    This link wraps the :func:`~chainer.functions.local_convolution_2d`
    function and holds the filter weight and bias array as parameters.
    Filters are not shared across spatial positions: each output location
    has its own filter. The weight therefore has shape
    ``(out_channels, out_h, out_w, in_channels, kh, kw)`` and the bias has
    shape ``(out_channels, out_h, out_w)``.

    Args:
        in_channels (int): Number of channels of input arrays. If either
            in_channels or in_size is ``None``, parameter initialization
            is deferred until the first forward pass, when the size is
            determined from the input.
        out_channels (int): Number of channels of output arrays.
        in_size (int or pair of ints): Spatial size of each image channel.
            ``in_size=k`` and ``in_size=(k, k)`` are equivalent. If either
            in_channels or in_size is ``None``, parameter initialization
            is deferred until the first forward pass, when the size is
            determined from the input.
        ksize (int or pair of ints): Size of filters (a.k.a. kernels).
            ``ksize=k`` and ``ksize=(k, k)`` are equivalent. This argument
            is required; passing ``None`` raises :class:`TypeError`.
        stride (int or pair of ints): Stride of filter applications.
            ``stride=s`` and ``stride=(s, s)`` are equivalent.
        nobias (bool): If ``True``, then this link does not use the bias term.
        initialW (:ref:`initializer <initializer>`): Initializer to
            initialize the weight. When it is :class:`numpy.ndarray`,
            its ``ndim`` should be 6.
        initial_bias (:ref:`initializer <initializer>`): Initializer to
            initialize the bias. If ``None``, the bias will be initialized to
            zero. When it is :class:`numpy.ndarray`, its ``ndim`` should be 3.
        **kwargs: Accepted for signature compatibility and ignored.

    .. seealso::
       See :func:`chainer.functions.local_convolution_2d`.

    Attributes:
        W (~chainer.Variable): Weight parameter.
        b (~chainer.Variable): Bias parameter, or ``None`` if ``nobias``.
    """

    def __init__(self, in_channels, out_channels, in_size=None, ksize=None,
                 stride=1, nobias=False, initialW=None, initial_bias=None,
                 **kwargs):
        super(LocalConvolution2D, self).__init__()
        # ``ksize`` defaults to ``None`` only so that it can follow the
        # optional ``in_size`` in the signature. It cannot be inferred from
        # the input, so fail here. Otherwise the error is an obscure
        # ``int - None`` TypeError, possibly not until the first forward.
        if ksize is None:
            raise TypeError('ksize must be specified for LocalConvolution2D')
        self.ksize = ksize
        self.stride = _pair(stride)
        self.nobias = nobias
        self.out_channels = out_channels
        with self.init_scope():
            # Parameters are created uninitialized here. Their shapes depend
            # on the input size and are filled in by _initialize_params.
            W_initializer = initializers._get_initializer(initialW)
            self.W = variable.Parameter(W_initializer)

            if nobias:
                self.b = None
            else:
                if initial_bias is None:
                    initial_bias = 0
                bias_initializer = initializers._get_initializer(initial_bias)
                self.b = variable.Parameter(bias_initializer)

            if in_channels is not None and in_size is not None:
                # _initialize_params normalizes in_size with _pair itself.
                self._initialize_params(in_channels, in_size)

    def _initialize_params(self, in_channels, in_size):
        """Allocate ``W`` (and ``b`` unless ``nobias``) for a given input.

        Args:
            in_channels (int): Number of input channels.
            in_size (int or pair of ints): Spatial size ``(h, w)`` of the
                input.
        """
        kh, kw = _pair(self.ksize)
        ih, iw = _pair(in_size)
        # Each output position has its own filter, so the output spatial
        # size is part of both parameter shapes.
        oh = _conv_output_length(ih, kh, self.stride[0])
        ow = _conv_output_length(iw, kw, self.stride[1])
        self.W.initialize((self.out_channels, oh, ow, in_channels, kh, kw))
        if not self.nobias:
            self.b.initialize((self.out_channels, oh, ow))

    def forward(self, x):
        """Applies the local convolution layer.

        If the parameters were not initialized in the constructor, they are
        initialized on the first call. ``x.shape[1]`` is taken as the
        channel count and ``x.shape[2:]`` as the spatial size.

        Args:
            x (~chainer.Variable): Input image. It is presumably of shape
                ``(batch, in_channels, h, w)``.

        Returns:
            ~chainer.Variable: Output of the convolution.

        """
        if self.W.array is None:
            self._initialize_params(x.shape[1], x.shape[2:])
        return local_convolution_2d.local_convolution_2d(
            x, self.W, self.b, self.stride)
104