1"""Inception V3 model for Keras.
2
3Note that the input image format for this model is different than for
4the VGG16 and ResNet models (299x299 instead of 224x224),
5and that the input preprocessing function is also different (same as Xception).
6
7# Reference
8
9- [Rethinking the Inception Architecture for Computer Vision](
10    http://arxiv.org/abs/1512.00567) (CVPR 2016)
11
12"""
13from __future__ import absolute_import
14from __future__ import division
15from __future__ import print_function
16
17import os
18
19from . import get_submodules_from_kwargs
20from . import imagenet_utils
21from .imagenet_utils import decode_predictions
22from .imagenet_utils import _obtain_input_shape
23
24
# Download locations of the ImageNet weights fetched by `InceptionV3` when
# `weights='imagenet'`: the full model, and the variant used when
# `include_top` is False ("notop").
WEIGHTS_PATH = (
    'https://github.com/fchollet/deep-learning-models/releases/download/v0.5/'
    'inception_v3_weights_tf_dim_ordering_tf_kernels.h5')
WEIGHTS_PATH_NO_TOP = (
    'https://github.com/fchollet/deep-learning-models/releases/download/v0.5/'
    'inception_v3_weights_tf_dim_ordering_tf_kernels_notop.h5')

# Keras submodule handles. They start out unset and are rebound by
# `InceptionV3` from `get_submodules_from_kwargs` on every call.
backend = layers = models = keras_utils = None
38
39
def conv2d_bn(x,
              filters,
              num_row,
              num_col,
              padding='same',
              strides=(1, 1),
              name=None):
    """Applies a bias-free `Conv2D`, then `BatchNormalization`, then ReLU.

    # Arguments
        x: input tensor.
        filters: number of filters of the `Conv2D` layer.
        num_row: kernel height of the convolution.
        num_col: kernel width of the convolution.
        padding: padding mode passed to `Conv2D`.
        strides: strides passed to `Conv2D`.
        name: base name of the ops. When given, the convolution is
            named `name + '_conv'`, the batch norm layer
            `name + '_bn'` and the activation `name`; when `None`,
            all three layers are left unnamed.

    # Returns
        Output tensor after `Conv2D`, `BatchNormalization` and the
        ReLU activation.
    """
    bn_name = None if name is None else name + '_bn'
    conv_name = None if name is None else name + '_conv'

    # Batch norm normalizes over the channel axis, whose position depends
    # on the backend's image data format.
    channels_first = backend.image_data_format() == 'channels_first'
    bn_axis = 1 if channels_first else 3

    # Layers are created and applied one after another so that their
    # creation order stays Conv2D -> BatchNormalization -> Activation.
    conv_out = layers.Conv2D(
        filters, (num_row, num_col),
        strides=strides,
        padding=padding,
        use_bias=False,
        name=conv_name)(x)
    bn_out = layers.BatchNormalization(
        axis=bn_axis, scale=False, name=bn_name)(conv_out)
    return layers.Activation('relu', name=name)(bn_out)
82
83
def InceptionV3(include_top=True,
                weights='imagenet',
                input_tensor=None,
                input_shape=None,
                pooling=None,
                classes=1000,
                **kwargs):
    """Builds the Inception v3 network as a Keras model.

    ImageNet pre-trained weights can optionally be downloaded and loaded.
    Tensors are laid out according to `backend.image_data_format()`.

    # Arguments
        include_top: whether to append the classification head
            (global average pooling followed by a softmax `Dense`
            layer) at the top of the network.
        weights: one of `None` (random initialization),
            'imagenet' (pre-training on ImageNet),
            or the path to the weights file to be loaded.
        input_tensor: optional Keras tensor (i.e. output of
            `layers.Input()`) to use as image input for the model.
        input_shape: optional shape tuple, only to be specified
            if `include_top` is False; otherwise the input shape
            has to be `(299, 299, 3)` (with `channels_last` data
            format) or `(3, 299, 299)` (with `channels_first` data
            format). It should have exactly 3 input channels,
            and width and height should be no smaller than 75.
            E.g. `(150, 150, 3)` would be one valid value.
        pooling: optional pooling mode for feature extraction
            when `include_top` is `False`.
            - `None` means that the output of the model will be
                the 4D tensor output of the last convolutional block.
            - `avg` means that global average pooling is applied to
                that output, so the model output is a 2D tensor.
            - `max` means that global max pooling is applied.
            Any other value leaves the output untouched.
        classes: optional number of classes to classify images
            into, only to be specified if `include_top` is True, and
            if no `weights` argument is specified.
        **kwargs: forwarded to `get_submodules_from_kwargs`, whose
            result provides the `backend`, `layers`, `models` and
            `keras_utils` modules used to build the model.

    # Returns
        A Keras model instance.

    # Raises
        ValueError: in case of invalid argument for `weights`,
            when `weights='imagenet'` and `include_top` are combined
            with `classes != 1000`, or invalid input shape.
    """
    global backend, layers, models, keras_utils
    backend, layers, models, keras_utils = get_submodules_from_kwargs(kwargs)

    if weights not in {'imagenet', None} and not os.path.exists(weights):
        raise ValueError('The `weights` argument should be either '
                         '`None` (random initialization), `imagenet` '
                         '(pre-training on ImageNet), '
                         'or the path to the weights file to be loaded.')

    if weights == 'imagenet' and include_top and classes != 1000:
        raise ValueError('If using `weights` as `"imagenet"` with `include_top`'
                         ' as true, `classes` should be 1000')

    # Determine proper input shape
    input_shape = _obtain_input_shape(
        input_shape,
        default_size=299,
        min_size=75,
        data_format=backend.image_data_format(),
        require_flatten=include_top,
        weights=weights)

    # Use the caller's tensor directly when it already is a Keras tensor;
    # otherwise wrap it (or nothing) in a fresh `Input` layer.
    if input_tensor is None:
        img_input = layers.Input(shape=input_shape)
    elif backend.is_keras_tensor(input_tensor):
        img_input = input_tensor
    else:
        img_input = layers.Input(tensor=input_tensor, shape=input_shape)

    channels_first = backend.image_data_format() == 'channels_first'
    channel_axis = 1 if channels_first else 3

    def pooled_projection(tensor, filters):
        # 3x3 stride-1 average pooling followed by a 1x1 conv + BN.
        pooled = layers.AveragePooling2D((3, 3),
                                         strides=(1, 1),
                                         padding='same')(tensor)
        return conv2d_bn(pooled, filters, 1, 1)

    # NOTE: layers must be created in exactly this order. Unnamed layers
    # receive creation-order-dependent default names.

    # Stem. Spatial sizes quoted in the comments below are for the
    # default 299 x 299 input.
    x = conv2d_bn(img_input, 32, 3, 3, strides=(2, 2), padding='valid')
    x = conv2d_bn(x, 32, 3, 3, padding='valid')
    x = conv2d_bn(x, 64, 3, 3)
    x = layers.MaxPooling2D((3, 3), strides=(2, 2))(x)

    x = conv2d_bn(x, 80, 1, 1, padding='valid')
    x = conv2d_bn(x, 192, 3, 3, padding='valid')
    x = layers.MaxPooling2D((3, 3), strides=(2, 2))(x)

    # mixed 0: 35 x 35 x 256; mixed 1, 2: 35 x 35 x 288.
    # The three blocks differ only in the width of the pooling branch.
    for block_id, pool_filters in enumerate((32, 64, 64)):
        tower_1x1 = conv2d_bn(x, 64, 1, 1)

        tower_5x5 = conv2d_bn(x, 48, 1, 1)
        tower_5x5 = conv2d_bn(tower_5x5, 64, 5, 5)

        tower_3x3dbl = conv2d_bn(x, 64, 1, 1)
        tower_3x3dbl = conv2d_bn(tower_3x3dbl, 96, 3, 3)
        tower_3x3dbl = conv2d_bn(tower_3x3dbl, 96, 3, 3)

        tower_pool = pooled_projection(x, pool_filters)
        x = layers.concatenate(
            [tower_1x1, tower_5x5, tower_3x3dbl, tower_pool],
            axis=channel_axis,
            name='mixed' + str(block_id))

    # mixed 3: 17 x 17 x 768 (grid size reduction)
    reduce_3x3 = conv2d_bn(x, 384, 3, 3, strides=(2, 2), padding='valid')

    reduce_3x3dbl = conv2d_bn(x, 64, 1, 1)
    reduce_3x3dbl = conv2d_bn(reduce_3x3dbl, 96, 3, 3)
    reduce_3x3dbl = conv2d_bn(
        reduce_3x3dbl, 96, 3, 3, strides=(2, 2), padding='valid')

    reduce_pool = layers.MaxPooling2D((3, 3), strides=(2, 2))(x)
    x = layers.concatenate(
        [reduce_3x3, reduce_3x3dbl, reduce_pool],
        axis=channel_axis,
        name='mixed3')

    # mixed 4, 5, 6, 7: 17 x 17 x 768.
    # The four blocks differ only in the inner width of the 7x7 towers.
    for block_id, inner in enumerate((128, 160, 160, 192), 4):
        tower_1x1 = conv2d_bn(x, 192, 1, 1)

        tower_7x7 = conv2d_bn(x, inner, 1, 1)
        tower_7x7 = conv2d_bn(tower_7x7, inner, 1, 7)
        tower_7x7 = conv2d_bn(tower_7x7, 192, 7, 1)

        tower_7x7dbl = conv2d_bn(x, inner, 1, 1)
        tower_7x7dbl = conv2d_bn(tower_7x7dbl, inner, 7, 1)
        tower_7x7dbl = conv2d_bn(tower_7x7dbl, inner, 1, 7)
        tower_7x7dbl = conv2d_bn(tower_7x7dbl, inner, 7, 1)
        tower_7x7dbl = conv2d_bn(tower_7x7dbl, 192, 1, 7)

        tower_pool = pooled_projection(x, 192)
        x = layers.concatenate(
            [tower_1x1, tower_7x7, tower_7x7dbl, tower_pool],
            axis=channel_axis,
            name='mixed' + str(block_id))

    # mixed 8: 8 x 8 x 1280 (grid size reduction)
    reduce_3x3 = conv2d_bn(x, 192, 1, 1)
    reduce_3x3 = conv2d_bn(reduce_3x3, 320, 3, 3,
                           strides=(2, 2), padding='valid')

    reduce_7x7x3 = conv2d_bn(x, 192, 1, 1)
    reduce_7x7x3 = conv2d_bn(reduce_7x7x3, 192, 1, 7)
    reduce_7x7x3 = conv2d_bn(reduce_7x7x3, 192, 7, 1)
    reduce_7x7x3 = conv2d_bn(
        reduce_7x7x3, 192, 3, 3, strides=(2, 2), padding='valid')

    reduce_pool = layers.MaxPooling2D((3, 3), strides=(2, 2))(x)
    x = layers.concatenate(
        [reduce_3x3, reduce_7x7x3, reduce_pool],
        axis=channel_axis,
        name='mixed8')

    # mixed 9, 10: 8 x 8 x 2048
    for step in range(2):
        tower_1x1 = conv2d_bn(x, 320, 1, 1)

        # 1x1 conv that fans out into parallel 1x3 and 3x1 convs.
        stem_3x3 = conv2d_bn(x, 384, 1, 1)
        split_1x3 = conv2d_bn(stem_3x3, 384, 1, 3)
        split_3x1 = conv2d_bn(stem_3x3, 384, 3, 1)
        tower_3x3 = layers.concatenate(
            [split_1x3, split_3x1],
            axis=channel_axis,
            name='mixed9_' + str(step))

        # Same fan-out, preceded by an extra 3x3 conv. This inner
        # concatenation is intentionally left unnamed.
        stem_3x3dbl = conv2d_bn(x, 448, 1, 1)
        stem_3x3dbl = conv2d_bn(stem_3x3dbl, 384, 3, 3)
        dbl_1x3 = conv2d_bn(stem_3x3dbl, 384, 1, 3)
        dbl_3x1 = conv2d_bn(stem_3x3dbl, 384, 3, 1)
        tower_3x3dbl = layers.concatenate(
            [dbl_1x3, dbl_3x1], axis=channel_axis)

        tower_pool = pooled_projection(x, 192)
        x = layers.concatenate(
            [tower_1x1, tower_3x3, tower_3x3dbl, tower_pool],
            axis=channel_axis,
            name='mixed' + str(9 + step))

    if include_top:
        # Classification block
        x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
        x = layers.Dense(classes, activation='softmax', name='predictions')(x)
    elif pooling == 'avg':
        x = layers.GlobalAveragePooling2D()(x)
    elif pooling == 'max':
        x = layers.GlobalMaxPooling2D()(x)

    # Ensure that the model takes into account
    # any potential predecessors of `input_tensor`.
    if input_tensor is None:
        inputs = img_input
    else:
        inputs = keras_utils.get_source_inputs(input_tensor)

    # Create model.
    model = models.Model(inputs, x, name='inception_v3')

    # Load weights.
    if weights == 'imagenet':
        if include_top:
            fname = 'inception_v3_weights_tf_dim_ordering_tf_kernels.h5'
            origin = WEIGHTS_PATH
            checksum = '9a0d58056eeedaa3f26cb7ebd46da564'
        else:
            fname = 'inception_v3_weights_tf_dim_ordering_tf_kernels_notop.h5'
            origin = WEIGHTS_PATH_NO_TOP
            checksum = 'bcbd6486424b2319ff4ef7d526e38f63'
        weights_path = keras_utils.get_file(
            fname,
            origin,
            cache_subdir='models',
            file_hash=checksum)
        model.load_weights(weights_path)
    elif weights is not None:
        model.load_weights(weights)

    return model
396
397
def preprocess_input(x, **kwargs):
    """Preprocesses a batch of images for the Inception V3 model.

    Delegates to `imagenet_utils.preprocess_input` with `mode='tf'`.

    # Arguments
        x: a 4D numpy array consists of RGB values within [0, 255].
        **kwargs: extra keyword arguments forwarded unchanged to
            `imagenet_utils.preprocess_input`.

    # Returns
        Preprocessed array.
    """
    preprocess = imagenet_utils.preprocess_input
    return preprocess(x, mode='tf', **kwargs)
408