1"""ResNet50 model for Keras.
2
3# Reference:
4
5- [Deep Residual Learning for Image Recognition](
6    https://arxiv.org/abs/1512.03385) (CVPR 2016 Best Paper Award)
7
8Adapted from code contributed by BigMoyan.
9"""
10from __future__ import absolute_import
11from __future__ import division
12from __future__ import print_function
13
14import os
15import warnings
16
17from . import get_submodules_from_kwargs
18from . import imagenet_utils
19from .imagenet_utils import decode_predictions
20from .imagenet_utils import _obtain_input_shape
21
# Re-export of `imagenet_utils.preprocess_input` so callers of this module
# can preprocess images for ResNet50 without importing imagenet_utils.
preprocess_input = imagenet_utils.preprocess_input

# Download URL of the pre-trained ImageNet weights including the
# classification head (used when `include_top=True`).
WEIGHTS_PATH = ('https://github.com/fchollet/deep-learning-models/'
                'releases/download/v0.2/'
                'resnet50_weights_tf_dim_ordering_tf_kernels.h5')
# Download URL of the pre-trained ImageNet weights without the head
# (used when `include_top=False`).
WEIGHTS_PATH_NO_TOP = ('https://github.com/fchollet/deep-learning-models/'
                       'releases/download/v0.2/'
                       'resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5')

# Keras submodules. They stay None until `ResNet50()` fills them in from its
# keyword arguments; `identity_block` and `conv_block` read them, so those
# builders only work after `ResNet50()` has been called.
backend = layers = models = keras_utils = None
35
36
def identity_block(input_tensor, kernel_size, filters, stage, block):
    """Residual block whose shortcut is the unmodified input tensor.

    The main path is a 1x1 conv, a `kernel_size` conv ('same' padding) and a
    1x1 conv. Each conv is followed by batch normalization, and the first two
    are also followed by a ReLU. The main-path output is added element-wise
    to `input_tensor` and passed through a final ReLU. Because of that
    addition, `filters[2]` has to match the channel count of `input_tensor`.

    # Arguments
        input_tensor: input tensor.
        kernel_size: kernel size of the middle conv layer of the main path
            (required; every call in this module passes 3).
        filters: sequence of three integers, the filter counts of the three
            main-path conv layers.
        stage: integer stage label, used to build layer names.
        block: string block label ('a', 'b', ...), used to build layer names.

    # Returns
        Output tensor for the block.

    Layers are named 'res<stage><block>_branch2{a,b,c}' (convs) and
    'bn<stage><block>_branch2{a,b,c}' (batch norms). This reads the
    module-level `backend` and `layers`, which are set by `ResNet50()`.
    """
    filters1, filters2, filters3 = filters
    bn_axis = 3 if backend.image_data_format() == 'channels_last' else 1
    tag = str(stage) + block
    conv_prefix = 'res' + tag + '_branch'
    bn_prefix = 'bn' + tag + '_branch'

    def conv_bn(tensor, n_filters, size, suffix, padding='valid'):
        # One Conv2D + BatchNormalization pair sharing the same name suffix.
        tensor = layers.Conv2D(n_filters, size,
                               padding=padding,
                               kernel_initializer='he_normal',
                               name=conv_prefix + suffix)(tensor)
        return layers.BatchNormalization(axis=bn_axis,
                                         name=bn_prefix + suffix)(tensor)

    out = conv_bn(input_tensor, filters1, (1, 1), '2a')
    out = layers.Activation('relu')(out)
    out = conv_bn(out, filters2, kernel_size, '2b', padding='same')
    out = layers.Activation('relu')(out)
    out = conv_bn(out, filters3, (1, 1), '2c')

    # No ReLU before the merge: the activation comes after the residual sum.
    out = layers.add([out, input_tensor])
    return layers.Activation('relu')(out)
80
81
def conv_block(input_tensor,
               kernel_size,
               filters,
               stage,
               block,
               strides=(2, 2)):
    """Residual block with a 1x1 convolution (projection) on the shortcut.

    The main path is a 1x1 conv with `strides`, a `kernel_size` conv ('same'
    padding) and a 1x1 conv. Each conv is followed by batch normalization,
    and the first two are also followed by a ReLU. The shortcut is a 1x1
    conv with the same `strides` and `filters[2]` filters, followed by batch
    normalization. The two paths are added and passed through a final ReLU.

    # Arguments
        input_tensor: input tensor.
        kernel_size: kernel size of the middle conv layer of the main path
            (required; every call in this module passes 3).
        filters: sequence of three integers, the filter counts of the three
            main-path conv layers.
        stage: integer stage label, used to build layer names.
        block: string block label ('a', 'b', ...), used to build layer names.
        strides: strides of the first main-path conv and of the shortcut
            conv.

    # Returns
        Output tensor for the block.

    In this model, stage 2 calls this with strides=(1, 1). Stages 3 and
    later keep the default (2, 2) on both the first main-path conv and the
    shortcut, so both paths downsample alike.

    Layers are named 'res<stage><block>_branch{2a,2b,2c,1}' (convs) and
    'bn<stage><block>_branch{2a,2b,2c,1}' (batch norms). This reads the
    module-level `backend` and `layers`, which are set by `ResNet50()`.
    """
    filters1, filters2, filters3 = filters
    bn_axis = 3 if backend.image_data_format() == 'channels_last' else 1
    tag = str(stage) + block
    conv_prefix = 'res' + tag + '_branch'
    bn_prefix = 'bn' + tag + '_branch'

    def conv_bn(tensor, n_filters, size, suffix,
                padding='valid', conv_strides=(1, 1)):
        # One Conv2D + BatchNormalization pair sharing the same name suffix.
        tensor = layers.Conv2D(n_filters, size,
                               strides=conv_strides,
                               padding=padding,
                               kernel_initializer='he_normal',
                               name=conv_prefix + suffix)(tensor)
        return layers.BatchNormalization(axis=bn_axis,
                                         name=bn_prefix + suffix)(tensor)

    # Main path: the stride is applied on the leading 1x1 conv.
    out = conv_bn(input_tensor, filters1, (1, 1), '2a', conv_strides=strides)
    out = layers.Activation('relu')(out)
    out = conv_bn(out, filters2, kernel_size, '2b', padding='same')
    out = layers.Activation('relu')(out)
    out = conv_bn(out, filters3, (1, 1), '2c')

    # Projection shortcut, built after the main path as in the layer order.
    shortcut = conv_bn(input_tensor, filters3, (1, 1), '1',
                       conv_strides=strides)

    out = layers.add([out, shortcut])
    return layers.Activation('relu')(out)
140
141
def ResNet50(include_top=True,
             weights='imagenet',
             input_tensor=None,
             input_shape=None,
             pooling=None,
             classes=1000,
             **kwargs):
    """Build a ResNet50 Keras model, optionally loading ImageNet weights.

    The layer layout follows `backend.image_data_format()`, i.e. the data
    format set in your Keras config at `~/.keras/keras.json`.

    # Arguments
        include_top: whether to end the network with global average pooling
            and a softmax `Dense` classifier layer named 'fc1000'.
        weights: `None` (random initialization), 'imagenet' (download and
            load ImageNet pre-trained weights), or a path to an existing
            weights file to load.
        input_tensor: optional Keras tensor (e.g. output of `layers.Input()`)
            to use as the image input of the model. A non-Keras tensor is
            wrapped in `layers.Input(tensor=...)`.
        input_shape: optional shape tuple, only to be specified when
            `include_top` is False. Otherwise the input shape has to be
            `(224, 224, 3)` (`channels_last`) or `(3, 224, 224)`
            (`channels_first`). It should have exactly 3 input channels,
            and width and height should be at least 32,
            e.g. `(200, 200, 3)`.
        pooling: pooling mode for feature extraction, only used when
            `include_top` is False.
            - `None`: the output is the 4D tensor of the last convolutional
                block. Any value other than 'avg' or 'max' behaves like this
                and emits a warning about the Keras 2.2.0 output-shape change.
            - 'avg': global average pooling is applied, so the output is a
                2D tensor.
            - 'max': global max pooling is applied, so the output is a
                2D tensor.
        classes: number of classes of the final classifier. Only used when
            `include_top` is True, and must be 1000 when `weights` is
            'imagenet' and `include_top` is True.
        **kwargs: passed to `get_submodules_from_kwargs` to obtain the Keras
            `backend`, `layers`, `models` and utility modules; these are
            stored in the module-level globals.

    # Returns
        A Keras model named 'resnet50'.

    # Raises
        ValueError: if `weights` is invalid, if `classes` is not 1000 while
            loading ImageNet weights with the top, or if the input shape is
            invalid.
    """
    global backend, layers, models, keras_utils
    backend, layers, models, keras_utils = get_submodules_from_kwargs(kwargs)

    known_weights = weights in {'imagenet', None}
    if not (known_weights or os.path.exists(weights)):
        raise ValueError('The `weights` argument should be either '
                         '`None` (random initialization), `imagenet` '
                         '(pre-training on ImageNet), '
                         'or the path to the weights file to be loaded.')

    if weights == 'imagenet' and include_top and classes != 1000:
        raise ValueError('If using `weights` as `"imagenet"` with `include_top`'
                         ' as true, `classes` should be 1000')

    # Validate the requested input shape or fall back to the default size.
    input_shape = _obtain_input_shape(
        input_shape,
        default_size=224,
        min_size=32,
        data_format=backend.image_data_format(),
        require_flatten=include_top,
        weights=weights)

    if input_tensor is None:
        img_input = layers.Input(shape=input_shape)
    elif backend.is_keras_tensor(input_tensor):
        img_input = input_tensor
    else:
        # Wrap a raw tensor so it can feed the functional model.
        img_input = layers.Input(tensor=input_tensor, shape=input_shape)

    bn_axis = 3 if backend.image_data_format() == 'channels_last' else 1

    # Stem: padded 7x7 stride-2 conv, BN, ReLU, then padded 3x3 stride-2
    # max pooling.
    x = layers.ZeroPadding2D(padding=(3, 3), name='conv1_pad')(img_input)
    x = layers.Conv2D(64, (7, 7), strides=(2, 2), padding='valid',
                      kernel_initializer='he_normal', name='conv1')(x)
    x = layers.BatchNormalization(axis=bn_axis, name='bn_conv1')(x)
    x = layers.Activation('relu')(x)
    x = layers.ZeroPadding2D(padding=(1, 1), name='pool1_pad')(x)
    x = layers.MaxPooling2D((3, 3), strides=(2, 2))(x)

    # Residual stages. Each entry is (stage label, block filters, number of
    # identity blocks after the leading conv block, conv block strides).
    # Stage 2 uses stride 1; stages 3-5 downsample by 2.
    stage_specs = (
        (2, [64, 64, 256], 2, (1, 1)),
        (3, [128, 128, 512], 3, (2, 2)),
        (4, [256, 256, 1024], 5, (2, 2)),
        (5, [512, 512, 2048], 2, (2, 2)),
    )
    for stage, stage_filters, num_identity, first_strides in stage_specs:
        x = conv_block(x, 3, stage_filters, stage=stage, block='a',
                       strides=first_strides)
        # Identity blocks are labelled 'b', 'c', ... after conv block 'a'.
        for letter in 'bcdef'[:num_identity]:
            x = identity_block(x, 3, stage_filters, stage=stage,
                               block=letter)

    if include_top:
        x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
        x = layers.Dense(classes, activation='softmax', name='fc1000')(x)
    elif pooling == 'avg':
        x = layers.GlobalAveragePooling2D()(x)
    elif pooling == 'max':
        x = layers.GlobalMaxPooling2D()(x)
    else:
        warnings.warn('The output shape of `ResNet50(include_top=False)` '
                      'has been changed since Keras 2.2.0.')

    # Use the source inputs of `input_tensor` so that any layers that
    # produced it are included in the model.
    if input_tensor is None:
        inputs = img_input
    else:
        inputs = keras_utils.get_source_inputs(input_tensor)
    model = models.Model(inputs, x, name='resnet50')

    if weights == 'imagenet':
        if include_top:
            file_name = 'resnet50_weights_tf_dim_ordering_tf_kernels.h5'
            origin = WEIGHTS_PATH
            file_hash = 'a7b3fe01876f51b976af0dea6bc144eb'
        else:
            file_name = 'resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5'
            origin = WEIGHTS_PATH_NO_TOP
            file_hash = 'a268eb855778b3df3c7506639542a6af'
        weights_path = keras_utils.get_file(file_name,
                                            origin,
                                            cache_subdir='models',
                                            md5_hash=file_hash)
        model.load_weights(weights_path)
        # The downloaded files are TF-ordered kernels (per their names);
        # convert them when running on the Theano backend.
        if backend.backend() == 'theano':
            keras_utils.convert_all_kernels_in_model(model)
    elif weights is not None:
        model.load_weights(weights)

    return model
298