1"""ResNet, ResNetV2, and ResNeXt models for Keras.
2
3# Reference papers
4
5- [Deep Residual Learning for Image Recognition]
6  (https://arxiv.org/abs/1512.03385) (CVPR 2016 Best Paper Award)
7- [Identity Mappings in Deep Residual Networks]
8  (https://arxiv.org/abs/1603.05027) (ECCV 2016)
9- [Aggregated Residual Transformations for Deep Neural Networks]
10  (https://arxiv.org/abs/1611.05431) (CVPR 2017)
11
12# Reference implementations
13
14- [TensorNets]
15  (https://github.com/taehoonlee/tensornets/blob/master/tensornets/resnets.py)
16- [Caffe ResNet]
17  (https://github.com/KaimingHe/deep-residual-networks/tree/master/prototxt)
18- [Torch ResNetV2]
19  (https://github.com/facebook/fb.resnet.torch/blob/master/models/preresnet.lua)
20- [Torch ResNeXt]
21  (https://github.com/facebookresearch/ResNeXt/blob/master/models/resnext.lua)
22
23"""
24from __future__ import absolute_import
25from __future__ import division
26from __future__ import print_function
27
28import os
29
30from . import get_submodules_from_kwargs
31from .imagenet_utils import _obtain_input_shape
32
33
# Keras submodules used by every helper in this file. They are placeholders
# until `ResNet()` rebinds them from `get_submodules_from_kwargs(kwargs)`.
backend = layers = models = keras_utils = None


# URL prefix that a weights file name is appended to when downloading.
BASE_WEIGHTS_PATH = ('https://github.com/keras-team/'
                     'keras-applications/releases/download/resnet/')

# Model name -> (hash of the weights file with the top layer,
#                hash of the `_notop` weights file).
# The selected hash is handed to `keras_utils.get_file` as `file_hash`.
WEIGHTS_HASHES = {
    'resnet50': (
        '2cb95161c43110f7111970584f804107',
        '4d473c1dd8becc155b73f8504c6f6626',
    ),
    'resnet101': (
        'f1aeb4b969a6efcfb50fad2f0c20cfc5',
        '88cf7a10940856eca736dc7b7e228a21',
    ),
    'resnet152': (
        '100835be76be38e30d865e96f2aaae62',
        'ee4c566cf9a93f14d82f913c2dc6dd0c',
    ),
    'resnet50v2': (
        '3ef43a0b657b3be2300d5770ece849e0',
        'fac2f116257151a9d068a22e544a4917',
    ),
    'resnet101v2': (
        '6343647c601c52e1368623803854d971',
        'c0ed64b8031c3730f411d2eb4eea35b5',
    ),
    'resnet152v2': (
        'a49b44d1979771252814e80f8ec446f9',
        'ed17cf2e0169df9d443503ef94b23b33',
    ),
    'resnext50': (
        '67a5b30d522ed92f75a1f16eef299d1a',
        '62527c363bdd9ec598bed41947b379fc',
    ),
    'resnext101': (
        '34fb605428fcc7aa4d62f44404c11509',
        '0f678c91647380debd923963594981b3',
    ),
}
61
62
def block1(x, filters, kernel_size=3, stride=1,
           conv_shortcut=True, name=None):
    """A bottleneck residual block (conv -> batch norm -> relu ordering).

    # Arguments
        x: input tensor.
        filters: integer, filters of the bottleneck layer; the block
            outputs `4 * filters` channels.
        kernel_size: default 3, kernel size of the bottleneck layer.
        stride: default 1, stride of the first layer (also applied to
            the convolution shortcut when one is used).
        conv_shortcut: default True, use convolution shortcut if True,
            otherwise identity shortcut.
        name: string, block label; prefix of every layer name.

    # Returns
        Output tensor for the residual block.
    """
    if backend.image_data_format() == 'channels_last':
        channel_axis = 3
    else:
        channel_axis = 1

    def batch_norm(tensor, suffix):
        # Every normalization layer in the block shares axis and epsilon.
        return layers.BatchNormalization(axis=channel_axis, epsilon=1.001e-5,
                                         name=name + suffix)(tensor)

    # Shortcut branch. Only an exact `True` selects the projection.
    if conv_shortcut is not True:
        shortcut = x
    else:
        shortcut = layers.Conv2D(4 * filters, 1, strides=stride,
                                 name=name + '_0_conv')(x)
        shortcut = batch_norm(shortcut, '_0_bn')

    # 1x1 convolution that carries the block stride.
    residual = layers.Conv2D(filters, 1, strides=stride,
                             name=name + '_1_conv')(x)
    residual = batch_norm(residual, '_1_bn')
    residual = layers.Activation('relu', name=name + '_1_relu')(residual)

    # Bottleneck convolution; 'SAME' padding is requested from the layer.
    residual = layers.Conv2D(filters, kernel_size, padding='SAME',
                             name=name + '_2_conv')(residual)
    residual = batch_norm(residual, '_2_bn')
    residual = layers.Activation('relu', name=name + '_2_relu')(residual)

    # 1x1 expansion to `4 * filters` channels; no activation before the add.
    residual = layers.Conv2D(4 * filters, 1, name=name + '_3_conv')(residual)
    residual = batch_norm(residual, '_3_bn')

    merged = layers.Add(name=name + '_add')([shortcut, residual])
    return layers.Activation('relu', name=name + '_out')(merged)
107
108
def stack1(x, filters, blocks, stride1=2, name=None):
    """Chains `blocks` residual blocks built by `block1`.

    # Arguments
        x: input tensor.
        filters: integer, filters of the bottleneck layer in a block.
        blocks: integer, blocks in the stacked blocks.
        stride1: default 2, stride of the first layer in the first block.
        name: string, stack label.

    # Returns
        Output tensor for the stacked blocks.
    """
    # The first block applies `stride1` and keeps block1's default
    # convolution shortcut.
    out = block1(x, filters, stride=stride1, name=name + '_block1')

    # Remaining blocks (numbered from 2) use the identity shortcut.
    for index in range(2, blocks + 1):
        block_name = name + '_block' + str(index)
        out = block1(out, filters, conv_shortcut=False, name=block_name)
    return out
126
127
def block2(x, filters, kernel_size=3, stride=1,
           conv_shortcut=False, name=None):
    """A bottleneck residual block that normalizes and activates its input first.

    # Arguments
        x: input tensor.
        filters: integer, filters of the bottleneck layer; the block
            outputs `4 * filters` channels.
        kernel_size: default 3, kernel size of the bottleneck layer.
        stride: default 1, stride applied by the bottleneck convolution
            and by the shortcut.
        conv_shortcut: default False, use convolution shortcut if True,
            otherwise identity shortcut.
        name: string, block label.

    # Returns
        Output tensor for the residual block.
    """
    if backend.image_data_format() == 'channels_last':
        channel_axis = 3
    else:
        channel_axis = 1

    def batch_norm(tensor, suffix):
        # Every normalization layer in the block shares axis and epsilon.
        return layers.BatchNormalization(axis=channel_axis, epsilon=1.001e-5,
                                         name=name + suffix)(tensor)

    # Pre-activation: batch norm + relu applied to the block input.
    preact = batch_norm(x, '_preact_bn')
    preact = layers.Activation('relu', name=name + '_preact_relu')(preact)

    # Shortcut: projection of the pre-activated tensor, or the raw input
    # (sub-sampled by a 1x1 max-pool when the block is strided).
    if conv_shortcut is True:
        shortcut = layers.Conv2D(4 * filters, 1, strides=stride,
                                 name=name + '_0_conv')(preact)
    elif stride > 1:
        # No explicit name is given to this pooling layer.
        shortcut = layers.MaxPooling2D(1, strides=stride)(x)
    else:
        shortcut = x

    residual = layers.Conv2D(filters, 1, strides=1, use_bias=False,
                             name=name + '_1_conv')(preact)
    residual = batch_norm(residual, '_1_bn')
    residual = layers.Activation('relu', name=name + '_1_relu')(residual)

    # Explicit one-pixel zero padding, then the strided bottleneck
    # convolution with the layer's default padding.
    residual = layers.ZeroPadding2D(padding=((1, 1), (1, 1)),
                                    name=name + '_2_pad')(residual)
    residual = layers.Conv2D(filters, kernel_size, strides=stride,
                             use_bias=False, name=name + '_2_conv')(residual)
    residual = batch_norm(residual, '_2_bn')
    residual = layers.Activation('relu', name=name + '_2_relu')(residual)

    # 1x1 expansion; the sum is returned without a trailing activation.
    residual = layers.Conv2D(4 * filters, 1, name=name + '_3_conv')(residual)
    return layers.Add(name=name + '_out')([shortcut, residual])
172
173
def stack2(x, filters, blocks, stride1=2, name=None):
    """Chains `blocks` residual blocks built by `block2`.

    # Arguments
        x: input tensor.
        filters: integer, filters of the bottleneck layer in a block.
        blocks: integer, blocks in the stacked blocks.
        stride1: default 2, stride passed to the last block of the stack.
        name: string, stack label.

    # Returns
        Output tensor for the stacked blocks.
    """
    # Only the first block uses a convolution shortcut.
    out = block2(x, filters, conv_shortcut=True, name=name + '_block1')

    # Middle blocks: identity shortcut, block2's default stride.
    for index in range(2, blocks):
        block_name = name + '_block' + str(index)
        out = block2(out, filters, name=block_name)

    # The stride is applied by the final block of the stack.
    last_name = name + '_block' + str(blocks)
    return block2(out, filters, stride=stride1, name=last_name)
192
193
def block3(x, filters, kernel_size=3, stride=1, groups=32,
           conv_shortcut=True, name=None):
    """A residual block whose middle convolution is split into groups.

    # Arguments
        x: input tensor.
        filters: integer, filters of the bottleneck layer; the block
            outputs `(64 // groups) * filters` channels.
        kernel_size: default 3, kernel size of the bottleneck layer.
        stride: default 1, stride applied by the bottleneck convolution
            and by the convolution shortcut.
        groups: default 32, group size for grouped convolution.
        conv_shortcut: default True, use convolution shortcut if True,
            otherwise identity shortcut.
        name: string, block label.

    # Returns
        Output tensor for the residual block.
    """
    if backend.image_data_format() == 'channels_last':
        channel_axis = 3
    else:
        channel_axis = 1

    def batch_norm(tensor, suffix):
        # Every normalization layer in the block shares axis and epsilon.
        return layers.BatchNormalization(axis=channel_axis, epsilon=1.001e-5,
                                         name=name + suffix)(tensor)

    # Shortcut branch. Only an exact `True` selects the projection.
    if conv_shortcut is not True:
        shortcut = x
    else:
        shortcut = layers.Conv2D((64 // groups) * filters, 1, strides=stride,
                                 use_bias=False, name=name + '_0_conv')(x)
        shortcut = batch_norm(shortcut, '_0_bn')

    branch = layers.Conv2D(filters, 1, use_bias=False,
                           name=name + '_1_conv')(x)
    branch = batch_norm(branch, '_1_bn')
    branch = layers.Activation('relu', name=name + '_1_relu')(branch)

    # Grouped convolution built from a depthwise convolution: each input
    # channel yields `per_group` outputs, the channel axis is unfolded to
    # (groups, per_group, per_group), and axis 4 is summed away.
    per_group = filters // groups
    branch = layers.ZeroPadding2D(padding=((1, 1), (1, 1)),
                                  name=name + '_2_pad')(branch)
    branch = layers.DepthwiseConv2D(kernel_size, strides=stride,
                                    depth_multiplier=per_group,
                                    use_bias=False,
                                    name=name + '_2_conv')(branch)
    # Dimensions between the first and the last axis.
    # NOTE(review): this slice looks like it assumes channels_last - confirm.
    spatial_dims = backend.int_shape(branch)[1:-1]
    unfolded_shape = spatial_dims + (groups, per_group, per_group)
    branch = layers.Reshape(unfolded_shape)(branch)
    # An explicit Lambda output shape is only supplied on the Theano backend.
    if backend.backend() == 'theano':
        reduced_shape = spatial_dims + (groups, per_group)
    else:
        reduced_shape = None
    branch = layers.Lambda(
        lambda t: sum([t[:, :, :, :, k] for k in range(per_group)]),
        output_shape=reduced_shape, name=name + '_2_reduce')(branch)
    branch = layers.Reshape(spatial_dims + (filters,))(branch)
    branch = batch_norm(branch, '_2_bn')
    branch = layers.Activation('relu', name=name + '_2_relu')(branch)

    # 1x1 expansion; no activation before the add.
    branch = layers.Conv2D((64 // groups) * filters, 1,
                           use_bias=False, name=name + '_3_conv')(branch)
    branch = batch_norm(branch, '_3_bn')

    merged = layers.Add(name=name + '_add')([shortcut, branch])
    return layers.Activation('relu', name=name + '_out')(merged)
248
249
def stack3(x, filters, blocks, stride1=2, groups=32, name=None):
    """Chains `blocks` residual blocks built by `block3`.

    # Arguments
        x: input tensor.
        filters: integer, filters of the bottleneck layer in a block.
        blocks: integer, blocks in the stacked blocks.
        stride1: default 2, stride of the first layer in the first block.
        groups: default 32, group size for grouped convolution.
        name: string, stack label.

    # Returns
        Output tensor for the stacked blocks.
    """
    # The first block applies `stride1` and keeps block3's default
    # convolution shortcut.
    out = block3(x, filters, stride=stride1, groups=groups,
                 name=name + '_block1')

    # Remaining blocks (numbered from 2) use the identity shortcut.
    for index in range(2, blocks + 1):
        block_name = name + '_block' + str(index)
        out = block3(out, filters, groups=groups, conv_shortcut=False,
                     name=block_name)
    return out
269
270
def ResNet(stack_fn,
           preact,
           use_bias,
           model_name='resnet',
           include_top=True,
           weights='imagenet',
           input_tensor=None,
           input_shape=None,
           pooling=None,
           classes=1000,
           **kwargs):
    """Instantiates the ResNet, ResNetV2, and ResNeXt architecture.

    Optionally loads weights pre-trained on ImageNet.
    Note that the data format convention used by the model is
    the one specified in your Keras config at `~/.keras/keras.json`.

    # Arguments
        stack_fn: a function that returns output tensor for the
            stacked residual blocks.
        preact: whether to use pre-activation or not
            (True for ResNetV2, False for ResNet and ResNeXt).
        use_bias: whether to use biases for convolutional layers or not
            (True for ResNet and ResNetV2, False for ResNeXt).
        model_name: string, model name. Pre-trained ImageNet weights
            can only be fetched for the names listed in `WEIGHTS_HASHES`.
        include_top: whether to include the fully-connected
            layer at the top of the network.
        weights: one of `None` (random initialization),
              'imagenet' (pre-training on ImageNet),
              or the path to the weights file to be loaded.
        input_tensor: optional Keras tensor
            (i.e. output of `layers.Input()`)
            to use as image input for the model.
        input_shape: optional shape tuple, only to be specified
            if `include_top` is False (otherwise the input shape
            has to be `(224, 224, 3)` (with `channels_last` data format)
            or `(3, 224, 224)` (with `channels_first` data format)).
            It should have exactly 3 inputs channels.
        pooling: optional pooling mode for feature extraction
            when `include_top` is `False`.
            - `None` means that the output of the model will be
                the 4D tensor output of the
                last convolutional layer.
            - `avg` means that global average pooling
                will be applied to the output of the
                last convolutional layer, and thus
                the output of the model will be a 2D tensor.
            - `max` means that global max pooling will
                be applied.
        classes: optional number of classes to classify images
            into, only to be specified if `include_top` is True, and
            if no `weights` argument is specified.

    # Returns
        A Keras model instance.

    # Raises
        ValueError: in case of invalid argument for `weights`,
            or invalid input shape, or when `weights='imagenet'` is
            requested for a `model_name` that has no pre-trained weights.
    """
    # Bind the Keras submodules for this module; the block/stack helpers
    # read them as module globals.
    global backend, layers, models, keras_utils
    backend, layers, models, keras_utils = get_submodules_from_kwargs(kwargs)

    if not (weights in {'imagenet', None} or os.path.exists(weights)):
        raise ValueError('The `weights` argument should be either '
                         '`None` (random initialization), `imagenet` '
                         '(pre-training on ImageNet), '
                         'or the path to the weights file to be loaded.')

    if weights == 'imagenet' and include_top and classes != 1000:
        raise ValueError('If using `weights` as `"imagenet"` with `include_top`'
                         ' as true, `classes` should be 1000')

    # Fail early with a clear message instead of building the whole graph
    # and then trying to open a weights file literally named 'imagenet'.
    # A local path called 'imagenet' is still honoured, as before.
    if (weights == 'imagenet' and model_name not in WEIGHTS_HASHES
            and not os.path.exists(weights)):
        raise ValueError('Pre-trained ImageNet weights are not available '
                         'for model name `' + str(model_name) + '`. '
                         'Use `weights=None`, pass the path to a weights '
                         'file, or use one of these model names: ' +
                         ', '.join(sorted(WEIGHTS_HASHES)) + '.')

    # Determine proper input shape
    input_shape = _obtain_input_shape(input_shape,
                                      default_size=224,
                                      min_size=32,
                                      data_format=backend.image_data_format(),
                                      require_flatten=include_top,
                                      weights=weights)

    if input_tensor is None:
        img_input = layers.Input(shape=input_shape)
    elif not backend.is_keras_tensor(input_tensor):
        # Wrap a non-Keras tensor in an Input layer.
        img_input = layers.Input(tensor=input_tensor, shape=input_shape)
    else:
        img_input = input_tensor

    bn_axis = 3 if backend.image_data_format() == 'channels_last' else 1

    # Stem: padded 7x7 stride-2 convolution.
    x = layers.ZeroPadding2D(padding=((3, 3), (3, 3)), name='conv1_pad')(img_input)
    x = layers.Conv2D(64, 7, strides=2, use_bias=use_bias, name='conv1_conv')(x)

    # Without pre-activation the stem is normalized and activated here;
    # with pre-activation this is done after the stacks instead (see below).
    if preact is False:
        x = layers.BatchNormalization(axis=bn_axis, epsilon=1.001e-5,
                                      name='conv1_bn')(x)
        x = layers.Activation('relu', name='conv1_relu')(x)

    x = layers.ZeroPadding2D(padding=((1, 1), (1, 1)), name='pool1_pad')(x)
    x = layers.MaxPooling2D(3, strides=2, name='pool1_pool')(x)

    # Architecture-specific residual stages.
    x = stack_fn(x)

    if preact is True:
        x = layers.BatchNormalization(axis=bn_axis, epsilon=1.001e-5,
                                      name='post_bn')(x)
        x = layers.Activation('relu', name='post_relu')(x)

    if include_top:
        x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
        x = layers.Dense(classes, activation='softmax', name='probs')(x)
    else:
        # Any other `pooling` value leaves the 4D feature map untouched.
        if pooling == 'avg':
            x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
        elif pooling == 'max':
            x = layers.GlobalMaxPooling2D(name='max_pool')(x)

    # Ensure that the model takes into account
    # any potential predecessors of `input_tensor`.
    if input_tensor is not None:
        inputs = keras_utils.get_source_inputs(input_tensor)
    else:
        inputs = img_input

    # Create model.
    model = models.Model(inputs, x, name=model_name)

    # Load weights.
    if (weights == 'imagenet') and (model_name in WEIGHTS_HASHES):
        # Hash index 0 belongs to the full model, index 1 to the
        # `_notop` variant.
        if include_top:
            file_name = model_name + '_weights_tf_dim_ordering_tf_kernels.h5'
            file_hash = WEIGHTS_HASHES[model_name][0]
        else:
            file_name = model_name + '_weights_tf_dim_ordering_tf_kernels_notop.h5'
            file_hash = WEIGHTS_HASHES[model_name][1]
        weights_path = keras_utils.get_file(file_name,
                                            BASE_WEIGHTS_PATH + file_name,
                                            cache_subdir='models',
                                            file_hash=file_hash)
        model.load_weights(weights_path)
    elif weights is not None:
        # `weights` is a path to a weights file.
        model.load_weights(weights)

    return model
416
417
def ResNet50(include_top=True,
             weights='imagenet',
             input_tensor=None,
             input_shape=None,
             pooling=None,
             classes=1000,
             **kwargs):
    # Builds the 'resnet50' model from `stack1` stages: no pre-activation,
    # biased convolutions. `__doc__` is assigned from `ResNet` at the
    # bottom of this module.
    def stack_fn(x):
        # conv2 uses stride 1; later stages keep stack1's default stride.
        x = stack1(x, 64, 3, stride1=1, name='conv2')
        for width, depth, stage in ((128, 4, 'conv3'),
                                    (256, 6, 'conv4'),
                                    (512, 3, 'conv5')):
            x = stack1(x, width, depth, name=stage)
        return x

    return ResNet(stack_fn,
                  preact=False,
                  use_bias=True,
                  model_name='resnet50',
                  include_top=include_top,
                  weights=weights,
                  input_tensor=input_tensor,
                  input_shape=input_shape,
                  pooling=pooling,
                  classes=classes,
                  **kwargs)
436
437
def ResNet101(include_top=True,
              weights='imagenet',
              input_tensor=None,
              input_shape=None,
              pooling=None,
              classes=1000,
              **kwargs):
    # Builds the 'resnet101' model from `stack1` stages: no pre-activation,
    # biased convolutions. `__doc__` is assigned from `ResNet` at the
    # bottom of this module.
    def stack_fn(x):
        # conv2 uses stride 1; later stages keep stack1's default stride.
        x = stack1(x, 64, 3, stride1=1, name='conv2')
        for width, depth, stage in ((128, 4, 'conv3'),
                                    (256, 23, 'conv4'),
                                    (512, 3, 'conv5')):
            x = stack1(x, width, depth, name=stage)
        return x

    return ResNet(stack_fn,
                  preact=False,
                  use_bias=True,
                  model_name='resnet101',
                  include_top=include_top,
                  weights=weights,
                  input_tensor=input_tensor,
                  input_shape=input_shape,
                  pooling=pooling,
                  classes=classes,
                  **kwargs)
456
457
def ResNet152(include_top=True,
              weights='imagenet',
              input_tensor=None,
              input_shape=None,
              pooling=None,
              classes=1000,
              **kwargs):
    # Builds the 'resnet152' model from `stack1` stages: no pre-activation,
    # biased convolutions. `__doc__` is assigned from `ResNet` at the
    # bottom of this module.
    def stack_fn(x):
        # conv2 uses stride 1; later stages keep stack1's default stride.
        x = stack1(x, 64, 3, stride1=1, name='conv2')
        for width, depth, stage in ((128, 8, 'conv3'),
                                    (256, 36, 'conv4'),
                                    (512, 3, 'conv5')):
            x = stack1(x, width, depth, name=stage)
        return x

    return ResNet(stack_fn,
                  preact=False,
                  use_bias=True,
                  model_name='resnet152',
                  include_top=include_top,
                  weights=weights,
                  input_tensor=input_tensor,
                  input_shape=input_shape,
                  pooling=pooling,
                  classes=classes,
                  **kwargs)
476
477
def ResNet50V2(include_top=True,
               weights='imagenet',
               input_tensor=None,
               input_shape=None,
               pooling=None,
               classes=1000,
               **kwargs):
    # Builds the 'resnet50v2' model from `stack2` stages: pre-activation,
    # biased convolutions. `__doc__` is assigned from `ResNet` at the
    # bottom of this module.
    def stack_fn(x):
        # The first three stages keep stack2's default stride; conv5
        # uses stride 1.
        for width, depth, stage in ((64, 3, 'conv2'),
                                    (128, 4, 'conv3'),
                                    (256, 6, 'conv4')):
            x = stack2(x, width, depth, name=stage)
        x = stack2(x, 512, 3, stride1=1, name='conv5')
        return x

    return ResNet(stack_fn,
                  preact=True,
                  use_bias=True,
                  model_name='resnet50v2',
                  include_top=include_top,
                  weights=weights,
                  input_tensor=input_tensor,
                  input_shape=input_shape,
                  pooling=pooling,
                  classes=classes,
                  **kwargs)
496
497
def ResNet101V2(include_top=True,
                weights='imagenet',
                input_tensor=None,
                input_shape=None,
                pooling=None,
                classes=1000,
                **kwargs):
    # Builds the 'resnet101v2' model from `stack2` stages: pre-activation,
    # biased convolutions. `__doc__` is assigned from `ResNet` at the
    # bottom of this module.
    def stack_fn(x):
        # The first three stages keep stack2's default stride; conv5
        # uses stride 1.
        for width, depth, stage in ((64, 3, 'conv2'),
                                    (128, 4, 'conv3'),
                                    (256, 23, 'conv4')):
            x = stack2(x, width, depth, name=stage)
        x = stack2(x, 512, 3, stride1=1, name='conv5')
        return x

    return ResNet(stack_fn,
                  preact=True,
                  use_bias=True,
                  model_name='resnet101v2',
                  include_top=include_top,
                  weights=weights,
                  input_tensor=input_tensor,
                  input_shape=input_shape,
                  pooling=pooling,
                  classes=classes,
                  **kwargs)
516
517
def ResNet152V2(include_top=True,
                weights='imagenet',
                input_tensor=None,
                input_shape=None,
                pooling=None,
                classes=1000,
                **kwargs):
    # Builds the 'resnet152v2' model from `stack2` stages: pre-activation,
    # biased convolutions. `__doc__` is assigned from `ResNet` at the
    # bottom of this module.
    def stack_fn(x):
        # The first three stages keep stack2's default stride; conv5
        # uses stride 1.
        for width, depth, stage in ((64, 3, 'conv2'),
                                    (128, 8, 'conv3'),
                                    (256, 36, 'conv4')):
            x = stack2(x, width, depth, name=stage)
        x = stack2(x, 512, 3, stride1=1, name='conv5')
        return x

    return ResNet(stack_fn,
                  preact=True,
                  use_bias=True,
                  model_name='resnet152v2',
                  include_top=include_top,
                  weights=weights,
                  input_tensor=input_tensor,
                  input_shape=input_shape,
                  pooling=pooling,
                  classes=classes,
                  **kwargs)
536
537
def ResNeXt50(include_top=True,
              weights='imagenet',
              input_tensor=None,
              input_shape=None,
              pooling=None,
              classes=1000,
              **kwargs):
    # Builds the 'resnext50' model from `stack3` stages: no pre-activation,
    # bias-free convolutions. `__doc__` is assigned from `ResNet` at the
    # bottom of this module.
    def stack_fn(x):
        # conv2 uses stride 1; later stages keep stack3's default stride.
        x = stack3(x, 128, 3, stride1=1, name='conv2')
        for width, depth, stage in ((256, 4, 'conv3'),
                                    (512, 6, 'conv4'),
                                    (1024, 3, 'conv5')):
            x = stack3(x, width, depth, name=stage)
        return x

    return ResNet(stack_fn,
                  preact=False,
                  use_bias=False,
                  model_name='resnext50',
                  include_top=include_top,
                  weights=weights,
                  input_tensor=input_tensor,
                  input_shape=input_shape,
                  pooling=pooling,
                  classes=classes,
                  **kwargs)
556
557
def ResNeXt101(include_top=True,
               weights='imagenet',
               input_tensor=None,
               input_shape=None,
               pooling=None,
               classes=1000,
               **kwargs):
    # Builds the 'resnext101' model from `stack3` stages: no pre-activation,
    # bias-free convolutions. `__doc__` is assigned from `ResNet` at the
    # bottom of this module.
    def stack_fn(x):
        # conv2 uses stride 1; later stages keep stack3's default stride.
        x = stack3(x, 128, 3, stride1=1, name='conv2')
        for width, depth, stage in ((256, 4, 'conv3'),
                                    (512, 23, 'conv4'),
                                    (1024, 3, 'conv5')):
            x = stack3(x, width, depth, name=stage)
        return x

    return ResNet(stack_fn,
                  preact=False,
                  use_bias=False,
                  model_name='resnext101',
                  include_top=include_top,
                  weights=weights,
                  input_tensor=input_tensor,
                  input_shape=input_shape,
                  pooling=pooling,
                  classes=classes,
                  **kwargs)
576
577
# Each public constructor reuses the docstring of the generic `ResNet`
# builder.
ResNet50.__doc__ = ResNet.__doc__
ResNet101.__doc__ = ResNet.__doc__
ResNet152.__doc__ = ResNet.__doc__
ResNet50V2.__doc__ = ResNet.__doc__
ResNet101V2.__doc__ = ResNet.__doc__
ResNet152V2.__doc__ = ResNet.__doc__
ResNeXt50.__doc__ = ResNet.__doc__
ResNeXt101.__doc__ = ResNet.__doc__
586