1"""Inception-ResNet V2 model for Keras.
2
Model naming and structure follow the TF-slim implementation
(which has some additional layers and a different number of
filters from the original arXiv paper):
6https://github.com/tensorflow/models/blob/master/research/slim/nets/inception_resnet_v2.py
7
8Pre-trained ImageNet weights are also converted from TF-slim,
9which can be found in:
10https://github.com/tensorflow/models/tree/master/research/slim#pre-trained-models
11
12# Reference
13- [Inception-v4, Inception-ResNet and the Impact of
14   Residual Connections on Learning](https://arxiv.org/abs/1602.07261) (AAAI 2017)
15
16"""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import os
22
23from . import get_submodules_from_kwargs
24from . import imagenet_utils
25from .imagenet_utils import decode_predictions
26from .imagenet_utils import _obtain_input_shape
27
28
# Release page hosting the pre-trained weight files; `InceptionResNetV2`
# appends the weight file name to this URL when downloading.
BASE_WEIGHT_URL = ('https://github.com/fchollet/'
                   'deep-learning-models/releases/download/v0.7/')

# Keras submodules used by the builder functions in this module. They start
# out unset and are assigned by `InceptionResNetV2` from
# `get_submodules_from_kwargs(kwargs)`.
backend = layers = models = keras_utils = None
36
37
def preprocess_input(x, **kwargs):
    """Preprocess a batch of images for Inception-ResNet V2.

    This is a thin wrapper around `imagenet_utils.preprocess_input` that
    fixes `mode='tf'`.

    # Arguments
        x: a 4D numpy array consisting of RGB values within [0, 255].
        **kwargs: extra keyword arguments forwarded unchanged to
            `imagenet_utils.preprocess_input`.

    # Returns
        Preprocessed array.
    """
    forwarded = dict(kwargs)
    processed = imagenet_utils.preprocess_input(x, mode='tf', **forwarded)
    return processed
48
49
def conv2d_bn(x,
              filters,
              kernel_size,
              strides=1,
              padding='same',
              activation='relu',
              use_bias=False,
              name=None):
    """Stack a `Conv2D` layer, an optional batch norm and an optional activation.

    Batch normalization (created with `scale=False`) is added only when
    `use_bias` is `False`. An `Activation` layer is added only when
    `activation` is not `None`.

    # Arguments
        x: input tensor.
        filters: number of filters of the `Conv2D` layer.
        kernel_size: kernel size as in `Conv2D`.
        strides: strides as in `Conv2D`.
        padding: padding mode as in `Conv2D`.
        activation: activation for a separate `Activation` layer placed
            after the convolution (and batch norm, if any); `None` skips it.
        use_bias: whether the `Conv2D` layer uses a bias. When `True`, no
            batch normalization layer is added.
        name: name given to the `Conv2D` layer. When provided, the batch
            norm layer is named `name + '_bn'` and the activation layer
            `name + '_ac'`. When `None`, no explicit names are passed.

    # Returns
        Output tensor of the last layer that was added.
    """
    out = layers.Conv2D(filters, kernel_size, strides=strides, padding=padding,
                        use_bias=use_bias, name=name)(x)

    if not use_bias:
        # Normalize over the channel axis of the configured data layout.
        channels_first = backend.image_data_format() == 'channels_first'
        norm_layer = layers.BatchNormalization(
            axis=1 if channels_first else 3,
            scale=False,
            name=None if name is None else name + '_bn')
        out = norm_layer(out)

    if activation is not None:
        act_layer = layers.Activation(
            activation, name=None if name is None else name + '_ac')
        out = act_layer(out)

    return out
90
91
def inception_resnet_block(x, scale, block_type, block_idx, activation='relu'):
    """Append one Inception-ResNet block to `x` and return its output.

    Three block variants from the paper are supported. The variant is
    selected through `block_type`, which is the block name used in the
    official TF-slim implementation:
        - Inception-ResNet-A: `block_type='block35'`
        - Inception-ResNet-B: `block_type='block17'`
        - Inception-ResNet-C: `block_type='block8'`

    # Arguments
        x: input tensor.
        scale: factor applied to the residual branch before it is added to
            the shortcut. If `r` is the output of the residual (inception)
            branch, the block computes `x + scale * r`.
        block_type: `'block35'`, `'block17'` or `'block8'`; selects the
            structure of the residual branch.
        block_idx: an `int` used to build layer names. The named layers of
            the block share the prefix `block_type + '_' + str(block_idx)`.
            For example, `block_type='block35', block_idx=1` gives the
            prefix `'block35_1'`.
        activation: activation applied at the end of the block
            (see [activations](../activations.md)). With
            `activation=None` no activation layer is added
            (i.e. "linear" activation: `a(x) = x`).

    # Returns
        Output tensor for the block.

    # Raises
        ValueError: if `block_type` is not one of `'block35'`,
            `'block17'` or `'block8'`.
    """
    # For each block type, the parallel towers of the residual branch. Each
    # tower is a sequence of (filters, kernel_size) pairs chained through
    # `conv2d_bn` with its default arguments.
    tower_table = (
        ('block35', (((32, 1),),
                     ((32, 1), (32, 3)),
                     ((32, 1), (48, 3), (64, 3)))),
        ('block17', (((192, 1),),
                     ((128, 1), (160, [1, 7]), (192, [7, 1])))),
        ('block8', (((192, 1),),
                    ((192, 1), (224, [1, 3]), (256, [3, 1])))),
    )
    for known_type, towers in tower_table:
        if block_type == known_type:
            break
    else:
        raise ValueError('Unknown Inception-ResNet block type. '
                         'Expects "block35", "block17" or "block8", '
                         'but got: ' + str(block_type))

    # Build towers one after another so layers are created in a fixed order.
    branches = []
    for tower in towers:
        tower_out = x
        for filters, kernel_size in tower:
            tower_out = conv2d_bn(tower_out, filters, kernel_size)
        branches.append(tower_out)

    block_name = block_type + '_' + str(block_idx)
    channel_axis = 3 if backend.image_data_format() != 'channels_first' else 1
    mixed = layers.Concatenate(
        axis=channel_axis, name=block_name + '_mixed')(branches)

    # Project the concatenated towers back to the shortcut's channel count
    # (1x1 conv with bias, no batch norm, no activation).
    shortcut_shape = backend.int_shape(x)
    up = conv2d_bn(mixed,
                   shortcut_shape[channel_axis],
                   1,
                   activation=None,
                   use_bias=True,
                   name=block_name + '_conv')

    # Scaled residual sum: x + up * scale.
    x = layers.Lambda(lambda inputs, scale: inputs[0] + inputs[1] * scale,
                      output_shape=shortcut_shape[1:],
                      arguments={'scale': scale},
                      name=block_name)([x, up])

    if activation is None:
        return x
    return layers.Activation(activation, name=block_name + '_ac')(x)
175
176
def _stem(x):
    """Build the stem (35 x 35 x 192 output for the default 299 x 299 input)."""
    x = conv2d_bn(x, 32, 3, strides=2, padding='valid')
    x = conv2d_bn(x, 32, 3, padding='valid')
    x = conv2d_bn(x, 64, 3)
    x = layers.MaxPooling2D(3, strides=2)(x)
    x = conv2d_bn(x, 80, 1, padding='valid')
    x = conv2d_bn(x, 192, 3, padding='valid')
    return layers.MaxPooling2D(3, strides=2)(x)


def _mixed_5b(x, channel_axis):
    """Build the Mixed 5b Inception-A block (35 x 35 x 320 by default)."""
    tower_0 = conv2d_bn(x, 96, 1)
    tower_1 = conv2d_bn(conv2d_bn(x, 48, 1), 64, 5)
    tower_2 = conv2d_bn(conv2d_bn(conv2d_bn(x, 64, 1), 96, 3), 96, 3)
    pooled = layers.AveragePooling2D(3, strides=1, padding='same')(x)
    tower_pool = conv2d_bn(pooled, 64, 1)
    return layers.Concatenate(axis=channel_axis, name='mixed_5b')(
        [tower_0, tower_1, tower_2, tower_pool])


def _mixed_6a(x, channel_axis):
    """Build the Mixed 6a Reduction-A block (17 x 17 x 1088 by default)."""
    tower_0 = conv2d_bn(x, 384, 3, strides=2, padding='valid')
    tower_1 = conv2d_bn(x, 256, 1)
    tower_1 = conv2d_bn(tower_1, 256, 3)
    tower_1 = conv2d_bn(tower_1, 384, 3, strides=2, padding='valid')
    tower_pool = layers.MaxPooling2D(3, strides=2, padding='valid')(x)
    return layers.Concatenate(axis=channel_axis, name='mixed_6a')(
        [tower_0, tower_1, tower_pool])


def _mixed_7a(x, channel_axis):
    """Build the Mixed 7a Reduction-B block (8 x 8 x 2080 by default)."""
    tower_0 = conv2d_bn(x, 256, 1)
    tower_0 = conv2d_bn(tower_0, 384, 3, strides=2, padding='valid')
    tower_1 = conv2d_bn(x, 256, 1)
    tower_1 = conv2d_bn(tower_1, 288, 3, strides=2, padding='valid')
    tower_2 = conv2d_bn(x, 256, 1)
    tower_2 = conv2d_bn(tower_2, 288, 3)
    tower_2 = conv2d_bn(tower_2, 320, 3, strides=2, padding='valid')
    tower_pool = layers.MaxPooling2D(3, strides=2, padding='valid')(x)
    return layers.Concatenate(axis=channel_axis, name='mixed_7a')(
        [tower_0, tower_1, tower_2, tower_pool])


def _residual_stack(x, block_type, scale, block_indices):
    """Apply `inception_resnet_block` once for each index in `block_indices`."""
    for idx in block_indices:
        x = inception_resnet_block(x,
                                   scale=scale,
                                   block_type=block_type,
                                   block_idx=idx)
    return x


def InceptionResNetV2(include_top=True,
                      weights='imagenet',
                      input_tensor=None,
                      input_shape=None,
                      pooling=None,
                      classes=1000,
                      **kwargs):
    """Build the Inception-ResNet v2 model, optionally with ImageNet weights.

    The data format used by the model is the one reported by the Keras
    backend (`backend.image_data_format()`), i.e. the one configured in
    your Keras config at `~/.keras/keras.json`.

    # Arguments
        include_top: whether to include the fully-connected
            layer at the top of the network.
        weights: one of `None` (random initialization),
              'imagenet' (pre-training on ImageNet),
              or the path to the weights file to be loaded.
        input_tensor: optional Keras tensor (i.e. output of `layers.Input()`)
            to use as image input for the model.
        input_shape: optional shape tuple, only to be specified
            if `include_top` is `False` (otherwise the input shape
            has to be `(299, 299, 3)` (with `'channels_last'` data format)
            or `(3, 299, 299)` (with `'channels_first'` data format).
            It should have exactly 3 input channels,
            and width and height should be no smaller than 75.
            E.g. `(150, 150, 3)` would be one valid value.
        pooling: optional pooling mode for feature extraction
            when `include_top` is `False`.
            - `None` means that the output of the model will be
                the 4D tensor output of the last convolutional block.
            - `'avg'` means that global average pooling
                will be applied to the output of the
                last convolutional block, and thus
                the output of the model will be a 2D tensor.
            - `'max'` means that global max pooling will be applied.
            Any other value is ignored (no pooling layer is added).
        classes: optional number of classes to classify images
            into, only to be specified if `include_top` is `True`, and
            if no `weights` argument is specified.

    # Returns
        A Keras `Model` instance.

    # Raises
        ValueError: in case of invalid argument for `weights`,
            or invalid input shape.
    """
    global backend, layers, models, keras_utils
    backend, layers, models, keras_utils = get_submodules_from_kwargs(kwargs)

    known_weights = weights in {'imagenet', None}
    if not known_weights and not os.path.exists(weights):
        raise ValueError('The `weights` argument should be either '
                         '`None` (random initialization), `imagenet` '
                         '(pre-training on ImageNet), '
                         'or the path to the weights file to be loaded.')

    if weights == 'imagenet' and include_top and classes != 1000:
        raise ValueError('If using `weights` as `"imagenet"` with `include_top`'
                         ' as true, `classes` should be 1000')

    data_format = backend.image_data_format()
    input_shape = _obtain_input_shape(input_shape,
                                      default_size=299,
                                      min_size=75,
                                      data_format=data_format,
                                      require_flatten=include_top,
                                      weights=weights)
    channel_axis = 1 if data_format == 'channels_first' else 3

    if input_tensor is None:
        img_input = layers.Input(shape=input_shape)
    elif backend.is_keras_tensor(input_tensor):
        img_input = input_tensor
    else:
        # Wrap a raw (non-Keras) tensor in an Input layer.
        img_input = layers.Input(tensor=input_tensor, shape=input_shape)

    # Feature extractor; the stage order fixes layer creation order.
    x = _stem(img_input)
    x = _mixed_5b(x, channel_axis)
    x = _residual_stack(x, 'block35', 0.17, range(1, 11))
    x = _mixed_6a(x, channel_axis)
    x = _residual_stack(x, 'block17', 0.1, range(1, 21))
    x = _mixed_7a(x, channel_axis)
    x = _residual_stack(x, 'block8', 0.2, range(1, 10))
    # The tenth block8 is unscaled and has no trailing activation.
    x = inception_resnet_block(x,
                               scale=1.,
                               activation=None,
                               block_type='block8',
                               block_idx=10)
    # Final convolution block: 8 x 8 x 1536 for the default input size.
    x = conv2d_bn(x, 1536, 1, name='conv_7b')

    if include_top:
        x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
        x = layers.Dense(classes, activation='softmax', name='predictions')(x)
    elif pooling == 'avg':
        x = layers.GlobalAveragePooling2D()(x)
    elif pooling == 'max':
        x = layers.GlobalMaxPooling2D()(x)

    # When an input tensor was supplied, build the model from its source
    # inputs so any layers preceding it are included.
    if input_tensor is None:
        inputs = img_input
    else:
        inputs = keras_utils.get_source_inputs(input_tensor)

    model = models.Model(inputs, x, name='inception_resnet_v2')

    if weights == 'imagenet':
        if include_top:
            fname = 'inception_resnet_v2_weights_tf_dim_ordering_tf_kernels.h5'
            file_hash = 'e693bd0210a403b3192acc6073ad2e96'
        else:
            fname = ('inception_resnet_v2_weights_'
                     'tf_dim_ordering_tf_kernels_notop.h5')
            file_hash = 'd19885ff4a710c122648d3b5c3b684e4'
        weights_path = keras_utils.get_file(fname,
                                            BASE_WEIGHT_URL + fname,
                                            cache_subdir='models',
                                            file_hash=file_hash)
        model.load_weights(weights_path)
    elif weights is not None:
        model.load_weights(weights)

    return model
369