1import collections
2import os
3import sys
4import warnings
5
6import numpy
7try:
8    from PIL import Image
9    available = True
10except ImportError as e:
11    available = False
12    _import_error = e
13
14import chainer
15from chainer.dataset.convert import concat_examples
16from chainer.dataset import download
17from chainer import function
18from chainer.functions.activation.relu import relu
19from chainer.functions.activation.softmax import softmax
20from chainer.functions.array.reshape import reshape
21from chainer.functions.math.sum import sum
22from chainer.functions.pooling.average_pooling_2d import average_pooling_2d
23from chainer.functions.pooling.max_pooling_nd import max_pooling_2d
24from chainer.initializers import constant
25from chainer.initializers import normal
26from chainer import link
27from chainer.links.connection.convolution_2d import Convolution2D
28from chainer.links.connection.linear import Linear
29from chainer.links.normalization.batch_normalization import BatchNormalization
30from chainer.serializers import npz
31from chainer.utils import argument
32from chainer.utils import imgproc
33from chainer.variable import Variable
34
35
class ResNetLayers(link.Chain):

    """A pre-trained CNN model provided by MSRA.

    When you specify the path of the pre-trained chainer model serialized as
    a ``.npz`` file in the constructor, this chain model automatically
    initializes all the parameters with it.
    This model would be useful when you want to extract a semantic feature
    vector per image, or fine-tune the model on a different dataset.
    Note that unlike ``VGG16Layers``, it does not automatically download a
    pre-trained caffemodel. This caffemodel can be downloaded at
    `GitHub <https://github.com/KaimingHe/deep-residual-networks>`_.

    If you want to manually convert the pre-trained caffemodel to a chainer
    model that can be specified in the constructor,
    please use ``convert_caffemodel_to_npz`` classmethod instead.

    See: K. He et. al., `Deep Residual Learning for Image Recognition
    <https://arxiv.org/abs/1512.03385>`_

    Args:
        pretrained_model (str): the destination of the pre-trained
            chainer model serialized as a ``.npz`` file.
            If the given name ends with ``.caffemodel``, it is instead
            interpreted as the file name of a caffemodel located in
            ``$CHAINER_DATASET_ROOT/pfnet/chainer/models/``,
            where ``$CHAINER_DATASET_ROOT`` is set as
            ``$HOME/.chainer/dataset`` unless you specify another value
            by modifying the environment variable. In this case the
            caffemodel is converted to
            ``ResNet-{n_layers}-model.npz`` ({n_layers} being the number of
            layers given to this constructor), the converted chainer model
            is stored on the same directory and automatically used from
            the next time. Note that this base class does not interpret
            ``auto`` by itself; the fixed-depth subclasses translate
            ``auto`` into the corresponding caffemodel file name.
            If this argument is specified as ``None``, all the parameters
            are not initialized by the pre-trained model, but the
            convolutional layers use the default initializer of the
            original paper, i.e.,
            ``chainer.initializers.HeNormal(scale=1.0)``.
        n_layers (int): The number of layers of this model. It should be either
            50, 101, or 152.
        downsample_fb (bool): If this argument is specified as ``False``,
            it performs downsampling by placing stride 2
            on the 1x1 convolutional layers (the original MSRA ResNet).
            If this argument is specified as ``True``, it performs downsampling
            by placing stride 2 on the 3x3 convolutional layers
            (Facebook ResNet).

    Attributes:
        available_layers (list of str): The list of available layer names
            used by ``forward`` and ``extract`` methods.

    Raises:
        ValueError: If ``n_layers`` is not 50, 101, or 152.

    """

    def __init__(self, pretrained_model, n_layers, downsample_fb=False):
        super(ResNetLayers, self).__init__()

        if pretrained_model:
            # As a sampling process is time-consuming,
            # we employ a zero initializer for faster computation.
            # The weights are overwritten by the loaded model below.
            conv_kwargs = {'initialW': constant.Zero()}
        else:
            # employ default initializers used in the original paper
            conv_kwargs = {'initialW': normal.HeNormal(scale=1.0)}

        # The building blocks take the same initializer plus the
        # downsampling flavor; the stem convolution only takes the former.
        kwargs = conv_kwargs.copy()
        kwargs['downsample_fb'] = downsample_fb

        # Number of bottleneck units in res2, res3, res4 and res5.
        if n_layers == 50:
            block = [3, 4, 6, 3]
        elif n_layers == 101:
            block = [3, 4, 23, 3]
        elif n_layers == 152:
            block = [3, 8, 36, 3]
        else:
            raise ValueError('The n_layers argument should be either 50, 101,'
                             ' or 152, but {} was given.'.format(n_layers))

        with self.init_scope():
            self.conv1 = Convolution2D(3, 64, 7, 2, 3, **conv_kwargs)
            self.bn1 = BatchNormalization(64)
            self.res2 = BuildingBlock(block[0], 64, 64, 256, 1, **kwargs)
            self.res3 = BuildingBlock(block[1], 256, 128, 512, 2, **kwargs)
            self.res4 = BuildingBlock(block[2], 512, 256, 1024, 2, **kwargs)
            self.res5 = BuildingBlock(block[3], 1024, 512, 2048, 2, **kwargs)
            self.fc6 = Linear(2048, 1000)

        if pretrained_model and pretrained_model.endswith('.caffemodel'):
            # A caffemodel file name: convert it to npz (or reuse the
            # previously converted file) inside the dataset directory.
            _retrieve(n_layers, 'ResNet-{}-model.npz'.format(n_layers),
                      pretrained_model, self)
        elif pretrained_model:
            npz.load_npz(pretrained_model, self)

    @property
    def functions(self):
        """Ordered mapping from layer name to the callables producing it.

        The callables of each entry are applied one after another to the
        output of the previous entry.
        """
        return collections.OrderedDict([
            ('conv1', [self.conv1, self.bn1, relu]),
            ('pool1', [lambda x: max_pooling_2d(x, ksize=3, stride=2)]),
            ('res2', [self.res2]),
            ('res3', [self.res3]),
            ('res4', [self.res4]),
            ('res5', [self.res5]),
            ('pool5', [_global_average_pooling_2d]),
            ('fc6', [self.fc6]),
            ('prob', [softmax]),
        ])

    @property
    def available_layers(self):
        """Names of the layers that can be requested, in execution order."""
        return list(self.functions.keys())

    @classmethod
    def convert_caffemodel_to_npz(cls, path_caffemodel, path_npz, n_layers=50):
        """Converts a pre-trained caffemodel to a chainer model.

        Args:
            path_caffemodel (str): Path of the pre-trained caffemodel.
            path_npz (str): Path of the converted chainer model.
            n_layers (int): The number of layers of the caffemodel. It should
                be either 50, 101, or 152. It must be given explicitly for
                the 101 and 152 layer models, also when this method is
                called through one of the fixed-depth subclasses.

        Raises:
            ValueError: If ``n_layers`` is not 50, 101, or 152.

        """

        # Select the transfer routine first so that an unsupported depth is
        # rejected before the time-consuming caffemodel parsing starts.
        if n_layers == 50:
            transfer = _transfer_resnet50
        elif n_layers == 101:
            transfer = _transfer_resnet101
        elif n_layers == 152:
            transfer = _transfer_resnet152
        else:
            raise ValueError('The n_layers argument should be either 50, 101,'
                             ' or 152, but {} was given.'.format(n_layers))

        # As CaffeFunction uses shortcut symbols,
        # we import CaffeFunction here.
        from chainer.links.caffe.caffe_function import CaffeFunction
        caffemodel = CaffeFunction(path_caffemodel)
        # Instantiate ``ResNetLayers`` explicitly rather than ``cls``: the
        # constructors of the fixed-depth subclasses do not accept
        # ``n_layers``, so ``cls(...)`` raised TypeError when this method was
        # called through them.
        chainermodel = ResNetLayers(pretrained_model=None, n_layers=n_layers)
        transfer(caffemodel, chainermodel)
        npz.save_npz(path_npz, chainermodel, compression=False)

    def forward(self, x, layers=None, **kwargs):
        """forward(self, x, layers=['prob'])

        Computes all the feature maps specified by ``layers``.

        Args:
            x (~chainer.Variable): Input variable. It should be prepared by
                ``prepare`` function.
            layers (list of str): The list of layer names you want to extract.
                Names that are not in ``available_layers`` are ignored, i.e.,
                they do not appear in the returned dictionary.

        Returns:
            Dictionary of ~chainer.Variable: A dictionary in which
            the key contains the layer name and the value contains
            the corresponding feature map variable.

        """

        if layers is None:
            layers = ['prob']

        if kwargs:
            # The only keyword given a dedicated message is the legacy
            # ``test`` flag.
            argument.check_unexpected_kwargs(
                kwargs, test='test argument is not supported anymore. '
                'Use chainer.using_config')
            argument.assert_kwargs_empty(kwargs)

        h = x
        activations = {}
        target_layers = set(layers)
        for key, funcs in self.functions.items():
            # Stop as soon as every requested layer has been collected so
            # that the deeper layers are not computed needlessly.
            if not target_layers:
                break
            for func in funcs:
                h = func(h)
            if key in target_layers:
                activations[key] = h
                target_layers.remove(key)
        return activations

    def extract(self, images, layers=None, size=(224, 224), **kwargs):
        """extract(self, images, layers=['pool5'], size=(224, 224))

        Extracts all the feature maps of given images.

        The difference of directly executing ``forward`` is that
        it directly accepts images as an input and automatically
        transforms them to a proper variable. That is,
        it is also interpreted as a shortcut method that implicitly calls
        ``prepare`` and ``forward`` functions.

        Unlike ``predict`` method, this method does not override
        ``chainer.config.train`` and ``chainer.config.enable_backprop``
        configuration. If you want to extract features without updating
        model parameters, you need to manually set configuration when
        calling this method as follows:

        .. code-block:: python

            # model is an instance of ResNetLayers (50 or 101 or 152 layers)
            with chainer.using_config('train', False):
                with chainer.using_config('enable_backprop', False):
                    feature = model.extract([image])

        Args:
            images (iterable of PIL.Image or numpy.ndarray): Input images.
            layers (list of str): The list of layer names you want to extract.
            size (pair of ints): The resolution of resized images used as
                an input of CNN. All the given images are not resized
                if this argument is ``None``, but the resolutions of
                all the images should be the same.

        Returns:
            Dictionary of ~chainer.Variable: A dictionary in which
            the key contains the layer name and the value contains
            the corresponding feature map variable.

        """

        if layers is None:
            layers = ['pool5']

        if kwargs:
            argument.check_unexpected_kwargs(
                kwargs, test='test argument is not supported anymore. '
                'Use chainer.using_config',
                volatile='volatile argument is not supported anymore. '
                'Use chainer.using_config')
            argument.assert_kwargs_empty(kwargs)

        # Preprocess every image, stack them into one batch and move the
        # batch to the array module this link currently uses.
        x = concat_examples([prepare(img, size=size) for img in images])
        x = Variable(self.xp.asarray(x))
        return self(x, layers=layers)

    def predict(self, images, oversample=True):
        """Computes all the probabilities of given images.

        The images are resized to 256x256 and then cropped to 224x224.
        This method runs with ``chainer.config.train`` set to ``False`` and
        with back-propagation disabled.

        Args:
            images (iterable of PIL.Image or numpy.ndarray): Input images.
                When you specify a color image as a :class:`numpy.ndarray`,
                make sure that color order is RGB.
            oversample (bool): If ``True``, it averages results across
                center, corners, and mirrors. Otherwise, it uses only the
                center.

        Returns:
            ~chainer.Variable: Output that contains the class probabilities
            of given images.

        """

        x = concat_examples([prepare(img, size=(256, 256)) for img in images])
        if oversample:
            x = imgproc.oversample(x, crop_dims=(224, 224))
        else:
            # Central 224x224 window of the 256x256 image.
            x = x[:, :, 16:240, 16:240]
        # Use no_backprop_mode to reduce memory consumption
        with function.no_backprop_mode(), chainer.using_config('train', False):
            x = Variable(self.xp.asarray(x))
            y = self(x, layers=['prob'])['prob']
            if oversample:
                # Every image contributed 10 crops; regroup the crops per
                # image and average their probabilities.
                n = len(y) // 10
                y_shape = y.shape[1:]
                y = reshape(y, (n, 10) + y_shape)
                y = sum(y, axis=1) / 10
        return y
295
296
class ResNet50Layers(ResNetLayers):

    """A pre-trained CNN model with 50 layers provided by MSRA.

    Passing the path of a pre-trained chainer model serialized as a ``.npz``
    file to the constructor makes this chain initialize all of its
    parameters from that file.
    The model is handy for extracting a semantic feature vector per image
    and for fine-tuning on a different dataset.
    Note that unlike ``VGG16Layers``, it does not automatically download a
    pre-trained caffemodel. This caffemodel can be downloaded at
    `GitHub <https://github.com/KaimingHe/deep-residual-networks>`_.

    To convert the pre-trained caffemodel manually into a chainer model
    that can be given to the constructor, use
    ``ResNetLayers.convert_caffemodel_to_npz`` with ``n_layers=50``.

    ResNet50 has 25,557,096 trainable parameters, and it's 58% and 43% fewer
    than ResNet101 and ResNet152, respectively. On the other hand, the top-5
    classification accuracy on ImageNet dataset drops only 0.7% and 1.1% from
    ResNet101 and ResNet152, respectively. Therefore, ResNet50 may have the
    best balance between the accuracy and the model size. It would be basically
    just enough for many cases, but some advanced models for object detection
    or semantic segmentation use deeper ones as their building blocks, so these
    deeper ResNets are here for making reproduction work easier.

    See: K. He et. al., `Deep Residual Learning for Image Recognition
    <https://arxiv.org/abs/1512.03385>`_

    Args:
        pretrained_model (str): the destination of the pre-trained
            chainer model serialized as a ``.npz`` file.
            With ``auto`` (the default), the caffemodel
            ``$CHAINER_DATASET_ROOT/pfnet/chainer/models/ResNet-50-model.caffemodel``
            is loaded and converted automatically,
            where ``$CHAINER_DATASET_ROOT`` is set as
            ``$HOME/.chainer/dataset`` unless you specify another value
            by modifying the environment variable. In that case the
            converted chainer model is stored on the same directory and
            automatically used from the next time.
            With ``None``, no pre-trained model is loaded and the
            parameters keep the default initializer used in the original
            paper, i.e., ``chainer.initializers.HeNormal(scale=1.0)``.
        downsample_fb (bool): With ``False``, downsampling is performed by
            placing stride 2 on the 1x1 convolutional layers
            (the original MSRA ResNet).
            With ``True``, downsampling is performed by placing stride 2
            on the 3x3 convolutional layers (Facebook ResNet).

    Attributes:
        available_layers (list of str): The list of available layer names
            used by ``forward`` and ``extract`` methods.

    """

    def __init__(self, pretrained_model='auto', downsample_fb=False):
        # 'auto' is a shorthand for the canonical caffemodel file name;
        # every other value is handed to the base class untouched.
        model_source = ('ResNet-50-model.caffemodel'
                        if pretrained_model == 'auto' else pretrained_model)
        super(ResNet50Layers, self).__init__(
            model_source, n_layers=50, downsample_fb=downsample_fb)
359
360
class ResNet101Layers(ResNetLayers):

    """A pre-trained CNN model with 101 layers provided by MSRA.

    Passing the path of a pre-trained chainer model serialized as a ``.npz``
    file to the constructor makes this chain initialize all of its
    parameters from that file.
    The model is handy for extracting a semantic feature vector per image
    and for fine-tuning on a different dataset.
    Note that unlike ``VGG16Layers``, it does not automatically download a
    pre-trained caffemodel. This caffemodel can be downloaded at
    `GitHub <https://github.com/KaimingHe/deep-residual-networks>`_.

    To convert the pre-trained caffemodel manually into a chainer model
    that can be given to the constructor, use
    ``ResNetLayers.convert_caffemodel_to_npz`` with ``n_layers=101``.

    ResNet101 has 44,549,224 trainable parameters, and it's 43% fewer than
    ResNet152 model, while the top-5 classification accuracy on ImageNet
    dataset drops 1.1% from ResNet152. For many cases, ResNet50 may have the
    best balance between the accuracy and the model size.

    See: K. He et. al., `Deep Residual Learning for Image Recognition
    <https://arxiv.org/abs/1512.03385>`_

    Args:
        pretrained_model (str): the destination of the pre-trained
            chainer model serialized as a ``.npz`` file.
            With ``auto`` (the default), the caffemodel
            ``$CHAINER_DATASET_ROOT/pfnet/chainer/models/ResNet-101-model.caffemodel``
            is loaded and converted automatically,
            where ``$CHAINER_DATASET_ROOT`` is set as
            ``$HOME/.chainer/dataset`` unless you specify another value
            by modifying the environment variable. In that case the
            converted chainer model is stored on the same directory and
            automatically used from the next time.
            With ``None``, no pre-trained model is loaded and the
            parameters keep the default initializer used in the original
            paper, i.e., ``chainer.initializers.HeNormal(scale=1.0)``.
        downsample_fb (bool): With ``False``, downsampling is performed by
            placing stride 2 on the 1x1 convolutional layers
            (the original MSRA ResNet).
            With ``True``, downsampling is performed by placing stride 2
            on the 3x3 convolutional layers (Facebook ResNet).

    Attributes:
        available_layers (list of str): The list of available layer names
            used by ``forward`` and ``extract`` methods.

    """

    def __init__(self, pretrained_model='auto', downsample_fb=False):
        # 'auto' is a shorthand for the canonical caffemodel file name;
        # every other value is handed to the base class untouched.
        model_source = ('ResNet-101-model.caffemodel'
                        if pretrained_model == 'auto' else pretrained_model)
        super(ResNet101Layers, self).__init__(
            model_source, n_layers=101, downsample_fb=downsample_fb)
419
420
class ResNet152Layers(ResNetLayers):

    """A pre-trained CNN model with 152 layers provided by MSRA.

    Passing the path of a pre-trained chainer model serialized as a ``.npz``
    file to the constructor makes this chain initialize all of its
    parameters from that file.
    The model is handy for extracting a semantic feature vector per image
    and for fine-tuning on a different dataset.
    Note that unlike ``VGG16Layers``, it does not automatically download a
    pre-trained caffemodel. This caffemodel can be downloaded at
    `GitHub <https://github.com/KaimingHe/deep-residual-networks>`_.

    To convert the pre-trained caffemodel manually into a chainer model
    that can be given to the constructor, use
    ``ResNetLayers.convert_caffemodel_to_npz`` with ``n_layers=152``.

    ResNet152 has 60,192,872 trainable parameters, and it's the deepest ResNet
    model and it achieves the best result on ImageNet classification task in
    `ILSVRC 2015 <http://image-net.org/challenges/LSVRC/2015/results#loc>`_.

    See: K. He et. al., `Deep Residual Learning for Image Recognition
    <https://arxiv.org/abs/1512.03385>`_

    Args:
        pretrained_model (str): the destination of the pre-trained
            chainer model serialized as a ``.npz`` file.
            With ``auto`` (the default), the caffemodel
            ``$CHAINER_DATASET_ROOT/pfnet/chainer/models/ResNet-152-model.caffemodel``
            is loaded and converted automatically,
            where ``$CHAINER_DATASET_ROOT`` is set as
            ``$HOME/.chainer/dataset`` unless you specify another value
            by modifying the environment variable. In that case the
            converted chainer model is stored on the same directory and
            automatically used from the next time.
            With ``None``, no pre-trained model is loaded and the
            parameters keep the default initializer used in the original
            paper, i.e., ``chainer.initializers.HeNormal(scale=1.0)``.
        downsample_fb (bool): With ``False``, downsampling is performed by
            placing stride 2 on the 1x1 convolutional layers
            (the original MSRA ResNet).
            With ``True``, downsampling is performed by placing stride 2
            on the 3x3 convolutional layers (Facebook ResNet).

    Attributes:
        available_layers (list of str): The list of available layer names
            used by ``forward`` and ``extract`` methods.

    """

    def __init__(self, pretrained_model='auto', downsample_fb=False):
        # 'auto' is a shorthand for the canonical caffemodel file name;
        # every other value is handed to the base class untouched.
        model_source = ('ResNet-152-model.caffemodel'
                        if pretrained_model == 'auto' else pretrained_model)
        super(ResNet152Layers, self).__init__(
            model_source, n_layers=152, downsample_fb=downsample_fb)
478
479
def prepare(image, size=(224, 224)):
    """Converts the given image to a numpy array for ResNet.

    This function has to be applied to an image before it is fed to
    ``forward``: it optionally resizes the image, reorders the channels
    from RGB to BGR, subtracts the mean pixel, and moves the channel axis
    to the front.

    Args:
        image (PIL.Image or numpy.ndarray): Input image.
            If an input is ``numpy.ndarray``, its shape must be
            ``(height, width)``, ``(height, width, channels)``,
            or ``(channels, height, width)``, and
            the order of the channels must be RGB.
        size (pair of ints): Size of converted images.
            If ``None``, the given image is not resized.

    Returns:
        numpy.ndarray: The converted output array, laid out as
        ``(channels, height, width)`` with the channels in BGR order.

    Raises:
        ImportError: If PIL could not be imported.

    """

    if not available:
        raise ImportError('PIL cannot be loaded. Install Pillow!\n'
                          'The actual import error is as follows:\n' +
                          str(_import_error))

    dtype = chainer.get_dtype()

    if isinstance(image, numpy.ndarray):
        if image.ndim == 3:
            leading = image.shape[0]
            if leading == 1:
                # A single leading channel: drop it to get a 2-D image.
                image = image[0]
            elif leading == 3:
                # Channel-first layout: move the channels to the last axis.
                image = numpy.transpose(image, (1, 2, 0))
        image = Image.fromarray(image.astype(numpy.uint8))

    rgb = image.convert('RGB')
    if size:
        rgb = rgb.resize(size)

    pixels = numpy.asarray(rgb, dtype=dtype)
    # Reverse the channel axis: RGB -> BGR.
    bgr = pixels[:, :, ::-1]
    # NOTE: the original paper subtracts a fixed mean image. To support
    #       arbitrary input sizes the mean pixel is subtracted instead,
    #       as the VGG team does. The mean value used in ResNet is
    #       slightly different from that of VGG16.
    bgr -= numpy.array([103.063, 115.903, 123.152], dtype=dtype)
    return bgr.transpose((2, 0, 1))
527
528
class BuildingBlock(link.Chain):

    """A stack of bottleneck layers forming one stage of the network.

    The stage starts with one ``BottleneckA`` (registered as ``a``) and is
    followed by ``n_layers - 1`` ``BottleneckB`` layers (registered as
    ``b1``, ``b2``, ...).

    Args:
        n_layer (int): *(deprecated since v7.0.0)*
            `n_layer` is now deprecated for consistency of naming choice.
            Please use `n_layers` instead.
        n_layers (int): Number of layers used in the building block.
        in_channels (int): Number of channels of input arrays.
        mid_channels (int): Number of channels of intermediate arrays.
        out_channels (int): Number of channels of output arrays.
        stride (int or tuple of ints): Stride of filter application.
        initialW (4-D array): Initial weight value used in
            the convolutional layers.
        downsample_fb (bool): With ``False``, downsampling is performed by
            placing stride 2 on the 1x1 convolutional layers
            (the original MSRA ResNet).
            With ``True``, downsampling is performed by placing stride 2
            on the 3x3 convolutional layers (Facebook ResNet).

    """

    def __init__(self, n_layers=None, in_channels=None, mid_channels=None,
                 out_channels=None, stride=None, initialW=None,
                 downsample_fb=None, **kwargs):
        super(BuildingBlock, self).__init__()

        # The legacy keyword, when present, takes precedence over n_layers.
        try:
            legacy_n_layers = kwargs['n_layer']
        except KeyError:
            pass
        else:
            warnings.warn(
                'Argument `n_layer` is deprecated. '
                'Please use `n_layers` instead',
                DeprecationWarning)
            n_layers = legacy_n_layers

        with self.init_scope():
            self.a = BottleneckA(
                in_channels, mid_channels, out_channels, stride,
                initialW, downsample_fb)
            # Names of the child links in the order they are applied.
            self._forward = ['a']
            for index in range(1, n_layers):
                child_name = 'b{}'.format(index)
                setattr(self, child_name,
                        BottleneckB(out_channels, mid_channels, initialW))
                self._forward.append(child_name)

    def forward(self, x):
        """Applies every bottleneck layer of the block in order."""
        h = x
        for child_name in self._forward:
            h = getattr(self, child_name)(h)
        return h
581
582
class BottleneckA(link.Chain):

    """A bottleneck layer that reduces the resolution of the feature map.

    The input goes through a 1x1 -> 3x3 -> 1x1 convolution branch and,
    in parallel, through a strided 1x1 projection branch; both results are
    added and passed through ReLU.

    Args:
        in_channels (int): Number of channels of input arrays.
        mid_channels (int): Number of channels of intermediate arrays.
        out_channels (int): Number of channels of output arrays.
        stride (int or tuple of ints): Stride of filter application.
        initialW (4-D array): Initial weight value used in
            the convolutional layers.
        downsample_fb (bool): With ``False``, downsampling is performed by
            placing stride 2 on the 1x1 convolutional layers
            (the original MSRA ResNet).
            With ``True``, downsampling is performed by placing stride 2
            on the 3x3 convolutional layers (Facebook ResNet).
    """

    def __init__(self, in_channels, mid_channels, out_channels,
                 stride=2, initialW=None, downsample_fb=False):
        super(BottleneckA, self).__init__()

        # The original MSRA ResNet strides in the first 1x1 convolution,
        # whereas the Facebook variant strides in the 3x3 convolution.
        if downsample_fb:
            stride_1x1, stride_3x3 = 1, stride
        else:
            stride_1x1, stride_3x3 = stride, 1

        # Options shared by all the convolutions of this layer.
        conv_opts = {'initialW': initialW, 'nobias': True}

        with self.init_scope():
            self.conv1 = Convolution2D(
                in_channels, mid_channels, 1, stride_1x1, 0, **conv_opts)
            self.bn1 = BatchNormalization(mid_channels)
            self.conv2 = Convolution2D(
                mid_channels, mid_channels, 3, stride_3x3, 1, **conv_opts)
            self.bn2 = BatchNormalization(mid_channels)
            self.conv3 = Convolution2D(
                mid_channels, out_channels, 1, 1, 0, **conv_opts)
            self.bn3 = BatchNormalization(out_channels)
            # Projection branch that matches the shape of the main branch.
            self.conv4 = Convolution2D(
                in_channels, out_channels, 1, stride, 0, **conv_opts)
            self.bn4 = BatchNormalization(out_channels)

    def forward(self, x):
        """Returns ``relu(main_branch(x) + projection(x))``."""
        main = relu(self.bn1(self.conv1(x)))
        main = relu(self.bn2(self.conv2(main)))
        main = self.bn3(self.conv3(main))
        projection = self.bn4(self.conv4(x))
        return relu(main + projection)
633
634
class BottleneckB(link.Chain):

    """A bottleneck layer that maintains the resolution of the feature map.

    The input goes through a 1x1 -> 3x3 -> 1x1 convolution branch whose
    result is added to the unchanged input and passed through ReLU.

    Args:
        in_channels (int): Number of channels of input and output arrays.
        mid_channels (int): Number of channels of intermediate arrays.
        initialW (4-D array): Initial weight value used in
            the convolutional layers.
    """

    def __init__(self, in_channels, mid_channels, initialW=None):
        super(BottleneckB, self).__init__()

        # Options shared by all the convolutions of this layer.
        conv_opts = {'initialW': initialW, 'nobias': True}

        with self.init_scope():
            self.conv1 = Convolution2D(
                in_channels, mid_channels, 1, 1, 0, **conv_opts)
            self.bn1 = BatchNormalization(mid_channels)
            self.conv2 = Convolution2D(
                mid_channels, mid_channels, 3, 1, 1, **conv_opts)
            self.bn2 = BatchNormalization(mid_channels)
            self.conv3 = Convolution2D(
                mid_channels, in_channels, 1, 1, 0, **conv_opts)
            self.bn3 = BatchNormalization(in_channels)

    def forward(self, x):
        """Returns ``relu(main_branch(x) + x)``."""
        main = relu(self.bn1(self.conv1(x)))
        main = relu(self.bn2(self.conv2(main)))
        main = self.bn3(self.conv3(main))
        return relu(main + x)
667
668
def _global_average_pooling_2d(x):
    """Averages ``x`` over its last two axes.

    ``x`` must have exactly four axes; the result is reshaped to the sizes
    of the first two axes.
    """
    batch_size, n_channels, height, width = x.shape
    # A pooling window as large as the whole map yields one value per map.
    pooled = average_pooling_2d(x, (height, width), stride=1)
    return reshape(pooled, (batch_size, n_channels))
674
675
def _transfer_components(src, dst_conv, dst_bn, bname, cname):
    """Copies one convolution / batch-normalization pair from ``src``.

    The source links are looked up on ``src`` under the names
    ``res{bname}_branch{cname}``, ``bn{bname}_branch{cname}`` and
    ``scale{bname}_branch{cname}``. The values are written in place into
    the arrays of ``dst_conv`` and ``dst_bn``.
    """
    suffix = '{}_branch{}'.format(bname, cname)
    conv = getattr(src, 'res' + suffix)
    bn = getattr(src, 'bn' + suffix)
    scale = getattr(src, 'scale' + suffix)

    dst_conv.W.array[:] = conv.W.array
    # The running statistics come from the source batch normalization ...
    dst_bn.avg_mean[:] = bn.avg_mean
    dst_bn.avg_var[:] = bn.avg_var
    # ... while gamma and beta come from the separate source scale link.
    dst_bn.gamma.array[:] = scale.W.array
    dst_bn.beta.array[:] = scale.bias.b.array
685
686
def _transfer_bottleneckA(src, dst, name):
    """Copies the four conv/bn pairs of a ``BottleneckA`` from ``src``.

    ``conv1``..``conv3`` are filled from the source branches ``2a``,
    ``2b`` and ``2c``; ``conv4`` is filled from branch ``1``.
    """
    for index, branch in enumerate(('2a', '2b', '2c', '1'), start=1):
        conv = getattr(dst, 'conv{}'.format(index))
        bn = getattr(dst, 'bn{}'.format(index))
        _transfer_components(src, conv, bn, name, branch)
692
693
def _transfer_bottleneckB(src, dst, name):
    """Copies the three conv/bn pairs of a ``BottleneckB`` from ``src``.

    ``conv1``..``conv3`` are filled from the source branches ``2a``,
    ``2b`` and ``2c``.
    """
    for index, branch in enumerate(('2a', '2b', '2c'), start=1):
        conv = getattr(dst, 'conv{}'.format(index))
        bn = getattr(dst, 'bn{}'.format(index))
        _transfer_components(src, conv, bn, name, branch)
698
699
def _transfer_block(src, dst, names):
    """Copies a whole building block from ``src`` into ``dst``.

    ``names[0]`` names the source unit for ``dst.a``; the remaining names
    are matched, in order, with ``dst.b1``, ``dst.b2``, and so on.
    """
    _transfer_bottleneckA(src, dst.a, names[0])
    for index, unit_name in enumerate(names[1:], start=1):
        unit = getattr(dst, 'b{}'.format(index))
        _transfer_bottleneckB(src, unit, unit_name)
705
706
def _transfer_resnet50(src, dst):
    """Copies every parameter of the 50-layer source model into ``dst``.

    Unlike the 101- and 152-layer variants, the bias of the first
    convolution is copied as well.
    """
    # First convolution and its batch normalization.
    stem_conv = dst.conv1
    stem_conv.W.array[:] = src.conv1.W.array
    stem_conv.b.array[:] = src.conv1.b.array

    stem_bn = dst.bn1
    stem_bn.avg_mean[:] = src.bn_conv1.avg_mean
    stem_bn.avg_var[:] = src.bn_conv1.avg_var
    stem_bn.gamma.array[:] = src.scale_conv1.W.array
    stem_bn.beta.array[:] = src.scale_conv1.bias.b.array

    # Residual stages: each stage is paired with the names of its source
    # units, e.g. stage 2 -> '2a', '2b', '2c'.
    unit_letters = ((2, 'abc'), (3, 'abcd'), (4, 'abcdef'), (5, 'abc'))
    for stage, letters in unit_letters:
        names = ['{}{}'.format(stage, letter) for letter in letters]
        _transfer_block(src, getattr(dst, 'res{}'.format(stage)), names)

    # Final fully-connected classifier.
    classifier = dst.fc6
    classifier.W.array[:] = src.fc1000.W.array
    classifier.b.array[:] = src.fc1000.b.array
722
723
def _transfer_resnet101(src, dst):
    """Copies every parameter of the 101-layer source model into ``dst``.

    Only the weight of the first convolution is copied; its bias is left
    untouched.
    """
    # First convolution and its batch normalization.
    dst.conv1.W.array[:] = src.conv1.W.array

    stem_bn = dst.bn1
    stem_bn.avg_mean[:] = src.bn_conv1.avg_mean
    stem_bn.avg_var[:] = src.bn_conv1.avg_var
    stem_bn.gamma.array[:] = src.scale_conv1.W.array
    stem_bn.beta.array[:] = src.scale_conv1.bias.b.array

    # Residual stages. In stages 3 and 4 the units after the first one are
    # numbered ('3b1', '3b2', ...) instead of lettered.
    stage_units = (
        ('res2', ['2a', '2b', '2c']),
        ('res3', ['3a', '3b1', '3b2', '3b3']),
        ('res4', ['4a'] + ['4b{}'.format(i) for i in range(1, 23)]),
        ('res5', ['5a', '5b', '5c']),
    )
    for stage_name, names in stage_units:
        _transfer_block(src, getattr(dst, stage_name), names)

    # Final fully-connected classifier.
    classifier = dst.fc6
    classifier.W.array[:] = src.fc1000.W.array
    classifier.b.array[:] = src.fc1000.b.array
739
740
def _transfer_resnet152(src, dst):
    """Copies every parameter of the 152-layer source model into ``dst``.

    Only the weight of the first convolution is copied; its bias is left
    untouched.
    """
    # First convolution and its batch normalization.
    dst.conv1.W.array[:] = src.conv1.W.array

    stem_bn = dst.bn1
    stem_bn.avg_mean[:] = src.bn_conv1.avg_mean
    stem_bn.avg_var[:] = src.bn_conv1.avg_var
    stem_bn.gamma.array[:] = src.scale_conv1.W.array
    stem_bn.beta.array[:] = src.scale_conv1.bias.b.array

    # Residual stages. In stages 3 and 4 the units after the first one are
    # numbered ('3b1', '3b2', ...) instead of lettered.
    stage_units = (
        ('res2', ['2a', '2b', '2c']),
        ('res3', ['3a'] + ['3b{}'.format(i) for i in range(1, 8)]),
        ('res4', ['4a'] + ['4b{}'.format(i) for i in range(1, 36)]),
        ('res5', ['5a', '5b', '5c']),
    )
    for stage_name, names in stage_units:
        _transfer_block(src, getattr(dst, stage_name), names)

    # Final fully-connected classifier.
    classifier = dst.fc6
    classifier.W.array[:] = src.fc1000.W.array
    classifier.b.array[:] = src.fc1000.b.array
757
758
def _make_npz(path_npz, path_caffemodel, model, n_layers):
    """Converts the caffemodel to ``path_npz`` and loads it into ``model``.

    A progress note is written to standard error first.

    Returns:
        The given ``model`` after the converted file was loaded into it.

    Raises:
        IOError: If ``path_caffemodel`` does not exist.
    """
    notice = 'Now loading caffemodel (usually it may take few minutes)\n'
    sys.stderr.write(notice)
    sys.stderr.flush()

    if os.path.exists(path_caffemodel):
        ResNetLayers.convert_caffemodel_to_npz(
            path_caffemodel, path_npz, n_layers)
        npz.load_npz(path_npz, model)
        return model

    raise IOError(
        'The pre-trained caffemodel does not exist. Please download it '
        'from \'https://github.com/KaimingHe/deep-residual-networks\', '
        'and place it on {}'.format(path_caffemodel))
771
772
def _retrieve(n_layers, name_npz, name_caffemodel, model):
    """Loads ``model`` from files in the ``pfnet/chainer/models/`` directory.

    Both file names are resolved against the dataset directory.
    ``download.cache_or_load_file`` receives the npz path together with a
    creator, which converts the caffemodel, and a loader, which reads the
    npz file into ``model``; its result is returned.
    """
    root = download.get_dataset_directory('pfnet/chainer/models/')
    path_npz = os.path.join(root, name_npz)
    path_caffemodel = os.path.join(root, name_caffemodel)

    def create(path):
        return _make_npz(path, path_caffemodel, model, n_layers)

    def load(path):
        return npz.load_npz(path, model)

    return download.cache_or_load_file(path_npz, create, load)
780