1"""Inception-ResNet V2 model for Keras. 2 3Model naming and structure follows TF-slim implementation 4(which has some additional layers and different number of 5filters from the original arXiv paper): 6https://github.com/tensorflow/models/blob/master/research/slim/nets/inception_resnet_v2.py 7 8Pre-trained ImageNet weights are also converted from TF-slim, 9which can be found in: 10https://github.com/tensorflow/models/tree/master/research/slim#pre-trained-models 11 12# Reference 13- [Inception-v4, Inception-ResNet and the Impact of 14 Residual Connections on Learning](https://arxiv.org/abs/1602.07261) (AAAI 2017) 15 16""" 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import os 22 23from . import get_submodules_from_kwargs 24from . import imagenet_utils 25from .imagenet_utils import decode_predictions 26from .imagenet_utils import _obtain_input_shape 27 28 29BASE_WEIGHT_URL = ('https://github.com/fchollet/deep-learning-models/' 30 'releases/download/v0.7/') 31 32backend = None 33layers = None 34models = None 35keras_utils = None 36 37 38def preprocess_input(x, **kwargs): 39 """Preprocesses a numpy array encoding a batch of images. 40 41 # Arguments 42 x: a 4D numpy array consists of RGB values within [0, 255]. 43 44 # Returns 45 Preprocessed array. 46 """ 47 return imagenet_utils.preprocess_input(x, mode='tf', **kwargs) 48 49 50def conv2d_bn(x, 51 filters, 52 kernel_size, 53 strides=1, 54 padding='same', 55 activation='relu', 56 use_bias=False, 57 name=None): 58 """Utility function to apply conv + BN. 59 60 # Arguments 61 x: input tensor. 62 filters: filters in `Conv2D`. 63 kernel_size: kernel size as in `Conv2D`. 64 strides: strides in `Conv2D`. 65 padding: padding mode in `Conv2D`. 66 activation: activation in `Conv2D`. 67 use_bias: whether to use a bias in `Conv2D`. 68 name: name of the ops; will become `name + '_ac'` for the activation 69 and `name + '_bn'` for the batch norm layer. 70 71 # Returns 72 Output tensor after applying `Conv2D` and `BatchNormalization`. 73 """ 74 x = layers.Conv2D(filters, 75 kernel_size, 76 strides=strides, 77 padding=padding, 78 use_bias=use_bias, 79 name=name)(x) 80 if not use_bias: 81 bn_axis = 1 if backend.image_data_format() == 'channels_first' else 3 82 bn_name = None if name is None else name + '_bn' 83 x = layers.BatchNormalization(axis=bn_axis, 84 scale=False, 85 name=bn_name)(x) 86 if activation is not None: 87 ac_name = None if name is None else name + '_ac' 88 x = layers.Activation(activation, name=ac_name)(x) 89 return x 90 91 92def inception_resnet_block(x, scale, block_type, block_idx, activation='relu'): 93 """Adds a Inception-ResNet block. 94 95 This function builds 3 types of Inception-ResNet blocks mentioned 96 in the paper, controlled by the `block_type` argument (which is the 97 block name used in the official TF-slim implementation): 98 - Inception-ResNet-A: `block_type='block35'` 99 - Inception-ResNet-B: `block_type='block17'` 100 - Inception-ResNet-C: `block_type='block8'` 101 102 # Arguments 103 x: input tensor. 104 scale: scaling factor to scale the residuals (i.e., the output of 105 passing `x` through an inception module) before adding them 106 to the shortcut branch. 107 Let `r` be the output from the residual branch, 108 the output of this block will be `x + scale * r`. 109 block_type: `'block35'`, `'block17'` or `'block8'`, determines 110 the network structure in the residual branch. 111 block_idx: an `int` used for generating layer names. 112 The Inception-ResNet blocks 113 are repeated many times in this network. 114 We use `block_idx` to identify 115 each of the repetitions. For example, 116 the first Inception-ResNet-A block 117 will have `block_type='block35', block_idx=0`, 118 and the layer names will have 119 a common prefix `'block35_0'`. 120 activation: activation function to use at the end of the block 121 (see [activations](../activations.md)). 122 When `activation=None`, no activation is applied 123 (i.e., "linear" activation: `a(x) = x`). 124 125 # Returns 126 Output tensor for the block. 127 128 # Raises 129 ValueError: if `block_type` is not one of `'block35'`, 130 `'block17'` or `'block8'`. 131 """ 132 if block_type == 'block35': 133 branch_0 = conv2d_bn(x, 32, 1) 134 branch_1 = conv2d_bn(x, 32, 1) 135 branch_1 = conv2d_bn(branch_1, 32, 3) 136 branch_2 = conv2d_bn(x, 32, 1) 137 branch_2 = conv2d_bn(branch_2, 48, 3) 138 branch_2 = conv2d_bn(branch_2, 64, 3) 139 branches = [branch_0, branch_1, branch_2] 140 elif block_type == 'block17': 141 branch_0 = conv2d_bn(x, 192, 1) 142 branch_1 = conv2d_bn(x, 128, 1) 143 branch_1 = conv2d_bn(branch_1, 160, [1, 7]) 144 branch_1 = conv2d_bn(branch_1, 192, [7, 1]) 145 branches = [branch_0, branch_1] 146 elif block_type == 'block8': 147 branch_0 = conv2d_bn(x, 192, 1) 148 branch_1 = conv2d_bn(x, 192, 1) 149 branch_1 = conv2d_bn(branch_1, 224, [1, 3]) 150 branch_1 = conv2d_bn(branch_1, 256, [3, 1]) 151 branches = [branch_0, branch_1] 152 else: 153 raise ValueError('Unknown Inception-ResNet block type. ' 154 'Expects "block35", "block17" or "block8", ' 155 'but got: ' + str(block_type)) 156 157 block_name = block_type + '_' + str(block_idx) 158 channel_axis = 1 if backend.image_data_format() == 'channels_first' else 3 159 mixed = layers.Concatenate( 160 axis=channel_axis, name=block_name + '_mixed')(branches) 161 up = conv2d_bn(mixed, 162 backend.int_shape(x)[channel_axis], 163 1, 164 activation=None, 165 use_bias=True, 166 name=block_name + '_conv') 167 168 x = layers.Lambda(lambda inputs, scale: inputs[0] + inputs[1] * scale, 169 output_shape=backend.int_shape(x)[1:], 170 arguments={'scale': scale}, 171 name=block_name)([x, up]) 172 if activation is not None: 173 x = layers.Activation(activation, name=block_name + '_ac')(x) 174 return x 175 176 177def InceptionResNetV2(include_top=True, 178 weights='imagenet', 179 input_tensor=None, 180 input_shape=None, 181 pooling=None, 182 classes=1000, 183 **kwargs): 184 """Instantiates the Inception-ResNet v2 architecture. 185 186 Optionally loads weights pre-trained on ImageNet. 187 Note that the data format convention used by the model is 188 the one specified in your Keras config at `~/.keras/keras.json`. 189 190 # Arguments 191 include_top: whether to include the fully-connected 192 layer at the top of the network. 193 weights: one of `None` (random initialization), 194 'imagenet' (pre-training on ImageNet), 195 or the path to the weights file to be loaded. 196 input_tensor: optional Keras tensor (i.e. output of `layers.Input()`) 197 to use as image input for the model. 198 input_shape: optional shape tuple, only to be specified 199 if `include_top` is `False` (otherwise the input shape 200 has to be `(299, 299, 3)` (with `'channels_last'` data format) 201 or `(3, 299, 299)` (with `'channels_first'` data format). 202 It should have exactly 3 inputs channels, 203 and width and height should be no smaller than 75. 204 E.g. `(150, 150, 3)` would be one valid value. 205 pooling: Optional pooling mode for feature extraction 206 when `include_top` is `False`. 207 - `None` means that the output of the model will be 208 the 4D tensor output of the last convolutional block. 209 - `'avg'` means that global average pooling 210 will be applied to the output of the 211 last convolutional block, and thus 212 the output of the model will be a 2D tensor. 213 - `'max'` means that global max pooling will be applied. 214 classes: optional number of classes to classify images 215 into, only to be specified if `include_top` is `True`, and 216 if no `weights` argument is specified. 217 218 # Returns 219 A Keras `Model` instance. 220 221 # Raises 222 ValueError: in case of invalid argument for `weights`, 223 or invalid input shape. 224 """ 225 global backend, layers, models, keras_utils 226 backend, layers, models, keras_utils = get_submodules_from_kwargs(kwargs) 227 228 if not (weights in {'imagenet', None} or os.path.exists(weights)): 229 raise ValueError('The `weights` argument should be either ' 230 '`None` (random initialization), `imagenet` ' 231 '(pre-training on ImageNet), ' 232 'or the path to the weights file to be loaded.') 233 234 if weights == 'imagenet' and include_top and classes != 1000: 235 raise ValueError('If using `weights` as `"imagenet"` with `include_top`' 236 ' as true, `classes` should be 1000') 237 238 # Determine proper input shape 239 input_shape = _obtain_input_shape( 240 input_shape, 241 default_size=299, 242 min_size=75, 243 data_format=backend.image_data_format(), 244 require_flatten=include_top, 245 weights=weights) 246 247 if input_tensor is None: 248 img_input = layers.Input(shape=input_shape) 249 else: 250 if not backend.is_keras_tensor(input_tensor): 251 img_input = layers.Input(tensor=input_tensor, shape=input_shape) 252 else: 253 img_input = input_tensor 254 255 # Stem block: 35 x 35 x 192 256 x = conv2d_bn(img_input, 32, 3, strides=2, padding='valid') 257 x = conv2d_bn(x, 32, 3, padding='valid') 258 x = conv2d_bn(x, 64, 3) 259 x = layers.MaxPooling2D(3, strides=2)(x) 260 x = conv2d_bn(x, 80, 1, padding='valid') 261 x = conv2d_bn(x, 192, 3, padding='valid') 262 x = layers.MaxPooling2D(3, strides=2)(x) 263 264 # Mixed 5b (Inception-A block): 35 x 35 x 320 265 branch_0 = conv2d_bn(x, 96, 1) 266 branch_1 = conv2d_bn(x, 48, 1) 267 branch_1 = conv2d_bn(branch_1, 64, 5) 268 branch_2 = conv2d_bn(x, 64, 1) 269 branch_2 = conv2d_bn(branch_2, 96, 3) 270 branch_2 = conv2d_bn(branch_2, 96, 3) 271 branch_pool = layers.AveragePooling2D(3, strides=1, padding='same')(x) 272 branch_pool = conv2d_bn(branch_pool, 64, 1) 273 branches = [branch_0, branch_1, branch_2, branch_pool] 274 channel_axis = 1 if backend.image_data_format() == 'channels_first' else 3 275 x = layers.Concatenate(axis=channel_axis, name='mixed_5b')(branches) 276 277 # 10x block35 (Inception-ResNet-A block): 35 x 35 x 320 278 for block_idx in range(1, 11): 279 x = inception_resnet_block(x, 280 scale=0.17, 281 block_type='block35', 282 block_idx=block_idx) 283 284 # Mixed 6a (Reduction-A block): 17 x 17 x 1088 285 branch_0 = conv2d_bn(x, 384, 3, strides=2, padding='valid') 286 branch_1 = conv2d_bn(x, 256, 1) 287 branch_1 = conv2d_bn(branch_1, 256, 3) 288 branch_1 = conv2d_bn(branch_1, 384, 3, strides=2, padding='valid') 289 branch_pool = layers.MaxPooling2D(3, strides=2, padding='valid')(x) 290 branches = [branch_0, branch_1, branch_pool] 291 x = layers.Concatenate(axis=channel_axis, name='mixed_6a')(branches) 292 293 # 20x block17 (Inception-ResNet-B block): 17 x 17 x 1088 294 for block_idx in range(1, 21): 295 x = inception_resnet_block(x, 296 scale=0.1, 297 block_type='block17', 298 block_idx=block_idx) 299 300 # Mixed 7a (Reduction-B block): 8 x 8 x 2080 301 branch_0 = conv2d_bn(x, 256, 1) 302 branch_0 = conv2d_bn(branch_0, 384, 3, strides=2, padding='valid') 303 branch_1 = conv2d_bn(x, 256, 1) 304 branch_1 = conv2d_bn(branch_1, 288, 3, strides=2, padding='valid') 305 branch_2 = conv2d_bn(x, 256, 1) 306 branch_2 = conv2d_bn(branch_2, 288, 3) 307 branch_2 = conv2d_bn(branch_2, 320, 3, strides=2, padding='valid') 308 branch_pool = layers.MaxPooling2D(3, strides=2, padding='valid')(x) 309 branches = [branch_0, branch_1, branch_2, branch_pool] 310 x = layers.Concatenate(axis=channel_axis, name='mixed_7a')(branches) 311 312 # 10x block8 (Inception-ResNet-C block): 8 x 8 x 2080 313 for block_idx in range(1, 10): 314 x = inception_resnet_block(x, 315 scale=0.2, 316 block_type='block8', 317 block_idx=block_idx) 318 x = inception_resnet_block(x, 319 scale=1., 320 activation=None, 321 block_type='block8', 322 block_idx=10) 323 324 # Final convolution block: 8 x 8 x 1536 325 x = conv2d_bn(x, 1536, 1, name='conv_7b') 326 327 if include_top: 328 # Classification block 329 x = layers.GlobalAveragePooling2D(name='avg_pool')(x) 330 x = layers.Dense(classes, activation='softmax', name='predictions')(x) 331 else: 332 if pooling == 'avg': 333 x = layers.GlobalAveragePooling2D()(x) 334 elif pooling == 'max': 335 x = layers.GlobalMaxPooling2D()(x) 336 337 # Ensure that the model takes into account 338 # any potential predecessors of `input_tensor`. 339 if input_tensor is not None: 340 inputs = keras_utils.get_source_inputs(input_tensor) 341 else: 342 inputs = img_input 343 344 # Create model. 345 model = models.Model(inputs, x, name='inception_resnet_v2') 346 347 # Load weights. 348 if weights == 'imagenet': 349 if include_top: 350 fname = 'inception_resnet_v2_weights_tf_dim_ordering_tf_kernels.h5' 351 weights_path = keras_utils.get_file( 352 fname, 353 BASE_WEIGHT_URL + fname, 354 cache_subdir='models', 355 file_hash='e693bd0210a403b3192acc6073ad2e96') 356 else: 357 fname = ('inception_resnet_v2_weights_' 358 'tf_dim_ordering_tf_kernels_notop.h5') 359 weights_path = keras_utils.get_file( 360 fname, 361 BASE_WEIGHT_URL + fname, 362 cache_subdir='models', 363 file_hash='d19885ff4a710c122648d3b5c3b684e4') 364 model.load_weights(weights_path) 365 elif weights is not None: 366 model.load_weights(weights) 367 368 return model 369