# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains the definition for the NASNet classification networks.

Paper: https://arxiv.org/abs/1707.07012
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from . import nasnet_utils

arg_scope = tf.contrib.framework.arg_scope
slim = tf.contrib.slim


# Notes for training NASNet Cifar Model
# -------------------------------------
# batch_size: 32
# learning rate: 0.025
# cosine (single period) learning rate decay
# auxiliary head loss weighting: 0.4
# clip global norm of all gradients by 5
def _cifar_config(is_training=True):
  drop_path_keep_prob = 1.0 if not is_training else 0.6
  return tf.contrib.training.HParams(
      stem_multiplier=3.0,
      drop_path_keep_prob=drop_path_keep_prob,
      num_cells=18,
      use_aux_head=1,
      num_conv_filters=32,
      dense_dropout_keep_prob=1.0,
      filter_scaling_rate=2.0,
      num_reduction_layers=2,
      data_format='NHWC',
      skip_reduction_layer_input=0,
      # 600 epochs with a batch size of 32
      # This is used for the drop path probabilities, since the drop
      # probability needs to increase over the course of training.
      total_training_steps=937500,
  )


# Notes for training large NASNet model on ImageNet
# -------------------------------------
# batch size (per replica): 16
# learning rate: 0.015 * 100
# learning rate decay factor: 0.97
# num epochs per decay: 2.4
# sync sgd with 100 replicas
# auxiliary head loss weighting: 0.4
# label smoothing: 0.1
# clip global norm of all gradients by 10
def _large_imagenet_config(is_training=True):
  drop_path_keep_prob = 1.0 if not is_training else 0.7
  return tf.contrib.training.HParams(
      stem_multiplier=3.0,
      dense_dropout_keep_prob=0.5,
      num_cells=18,
      filter_scaling_rate=2.0,
      num_conv_filters=168,
      drop_path_keep_prob=drop_path_keep_prob,
      use_aux_head=1,
      num_reduction_layers=2,
      data_format='NHWC',
      skip_reduction_layer_input=1,
      total_training_steps=250000,
  )


# Notes for training the mobile NASNet ImageNet model
# -------------------------------------
# batch size (per replica): 32
# learning rate: 0.04 * 50
# learning rate decay factor: 0.97
# num epochs per decay: 2.4
# sync sgd with 50 replicas
# auxiliary head loss weighting: 0.4
# label smoothing: 0.1
# clip global norm of all gradients by 10
def _mobile_imagenet_config():
  return tf.contrib.training.HParams(
      stem_multiplier=1.0,
      dense_dropout_keep_prob=0.5,
      num_cells=12,
      filter_scaling_rate=2.0,
      drop_path_keep_prob=1.0,
      num_conv_filters=44,
      use_aux_head=1,
      num_reduction_layers=2,
      data_format='NHWC',
      skip_reduction_layer_input=0,
      total_training_steps=250000,
  )


def nasnet_cifar_arg_scope(weight_decay=5e-4,
                           batch_norm_decay=0.9,
                           batch_norm_epsilon=1e-5):
  """Defines the default arg scope for the NASNet-A Cifar model.

  Args:
    weight_decay: The weight decay to use for regularizing the model.
    batch_norm_decay: Decay for batch norm moving average.
    batch_norm_epsilon: Small float added to variance to avoid dividing by zero
      in batch norm.

  Returns:
    An `arg_scope` to use for the NASNet Cifar Model.
  """
  batch_norm_params = {
      # Decay for the moving averages.
      'decay': batch_norm_decay,
      # epsilon to prevent 0s in variance.
      'epsilon': batch_norm_epsilon,
      'scale': True,
      'fused': True,
  }
  weights_regularizer = tf.contrib.layers.l2_regularizer(weight_decay)
  weights_initializer = tf.contrib.layers.variance_scaling_initializer(
      mode='FAN_OUT')
  with arg_scope([slim.fully_connected, slim.conv2d, slim.separable_conv2d],
                 weights_regularizer=weights_regularizer,
                 weights_initializer=weights_initializer):
    with arg_scope([slim.fully_connected],
                   activation_fn=None, scope='FC'):
      with arg_scope([slim.conv2d, slim.separable_conv2d],
                     activation_fn=None, biases_initializer=None):
        with arg_scope([slim.batch_norm], **batch_norm_params) as sc:
          return sc


def nasnet_mobile_arg_scope(weight_decay=4e-5,
                            batch_norm_decay=0.9997,
                            batch_norm_epsilon=1e-3):
  """Defines the default arg scope for the NASNet-A Mobile ImageNet model.

  Args:
    weight_decay: The weight decay to use for regularizing the model.
    batch_norm_decay: Decay for batch norm moving average.
    batch_norm_epsilon: Small float added to variance to avoid dividing by zero
      in batch norm.

  Returns:
    An `arg_scope` to use for the NASNet Mobile Model.
  """
  batch_norm_params = {
      # Decay for the moving averages.
      'decay': batch_norm_decay,
      # epsilon to prevent 0s in variance.
      'epsilon': batch_norm_epsilon,
      'scale': True,
      'fused': True,
  }
  weights_regularizer = tf.contrib.layers.l2_regularizer(weight_decay)
  weights_initializer = tf.contrib.layers.variance_scaling_initializer(
      mode='FAN_OUT')
  with arg_scope([slim.fully_connected, slim.conv2d, slim.separable_conv2d],
                 weights_regularizer=weights_regularizer,
                 weights_initializer=weights_initializer):
    with arg_scope([slim.fully_connected],
                   activation_fn=None, scope='FC'):
      with arg_scope([slim.conv2d, slim.separable_conv2d],
                     activation_fn=None, biases_initializer=None):
        with arg_scope([slim.batch_norm], **batch_norm_params) as sc:
          return sc


def nasnet_large_arg_scope(weight_decay=5e-5,
                           batch_norm_decay=0.9997,
                           batch_norm_epsilon=1e-3):
  """Defines the default arg scope for the NASNet-A Large ImageNet model.

  Args:
    weight_decay: The weight decay to use for regularizing the model.
    batch_norm_decay: Decay for batch norm moving average.
    batch_norm_epsilon: Small float added to variance to avoid dividing by zero
      in batch norm.

  Returns:
    An `arg_scope` to use for the NASNet Large Model.
  """
  batch_norm_params = {
      # Decay for the moving averages.
      'decay': batch_norm_decay,
      # epsilon to prevent 0s in variance.
      'epsilon': batch_norm_epsilon,
      'scale': True,
      'fused': True,
  }
  weights_regularizer = tf.contrib.layers.l2_regularizer(weight_decay)
  weights_initializer = tf.contrib.layers.variance_scaling_initializer(
      mode='FAN_OUT')
  with arg_scope([slim.fully_connected, slim.conv2d, slim.separable_conv2d],
                 weights_regularizer=weights_regularizer,
                 weights_initializer=weights_initializer):
    with arg_scope([slim.fully_connected],
                   activation_fn=None, scope='FC'):
      with arg_scope([slim.conv2d, slim.separable_conv2d],
                     activation_fn=None, biases_initializer=None):
        with arg_scope([slim.batch_norm], **batch_norm_params) as sc:
          return sc


def _build_aux_head(net, end_points, num_classes, hparams, scope):
  """Auxiliary head used for all models across all datasets."""
  with tf.variable_scope(scope):
    aux_logits = tf.identity(net)
    with tf.variable_scope('aux_logits'):
      aux_logits = slim.avg_pool2d(
          aux_logits, [5, 5], stride=3, padding='VALID')
      aux_logits = slim.conv2d(aux_logits, 128, [1, 1], scope='proj')
      aux_logits = slim.batch_norm(aux_logits, scope='aux_bn0')
      aux_logits = tf.nn.relu(aux_logits)
      # Shape of feature map before the final layer.
      shape = aux_logits.shape
      if hparams.data_format == 'NHWC':
        shape = shape[1:3]
      else:
        shape = shape[2:4]
      aux_logits = slim.conv2d(aux_logits, 768, shape, padding='VALID')
      aux_logits = slim.batch_norm(aux_logits, scope='aux_bn1')
      aux_logits = tf.nn.relu(aux_logits)
      aux_logits = tf.contrib.layers.flatten(aux_logits)
      aux_logits = slim.fully_connected(aux_logits, num_classes)
      end_points['AuxLogits'] = aux_logits


def _imagenet_stem(inputs, hparams, stem_cell):
  """Stem used for models trained on ImageNet."""
  num_stem_cells = 2

  # 149 x 149 x 32
  num_stem_filters = int(32 * hparams.stem_multiplier)
  net = slim.conv2d(
      inputs, num_stem_filters, [3, 3], stride=2, scope='conv0',
      padding='VALID')
  net = slim.batch_norm(net, scope='conv0_bn')

  # Run the reduction cells
  cell_outputs = [None, net]
  filter_scaling = 1.0 / (hparams.filter_scaling_rate**num_stem_cells)
  for cell_num in range(num_stem_cells):
    net = stem_cell(
        net,
        scope='cell_stem_{}'.format(cell_num),
        filter_scaling=filter_scaling,
        stride=2,
        prev_layer=cell_outputs[-2],
        cell_num=cell_num)
    cell_outputs.append(net)
    filter_scaling *= hparams.filter_scaling_rate
  return net, cell_outputs


def _cifar_stem(inputs, hparams):
  """Stem used for models trained on Cifar."""
  num_stem_filters = int(hparams.num_conv_filters * hparams.stem_multiplier)
  net = slim.conv2d(
      inputs,
      num_stem_filters,
      3,
      scope='l1_stem_3x3')
  net = slim.batch_norm(net, scope='l1_stem_bn')
  return net, [None, net]


def build_nasnet_cifar(
    images, num_classes, is_training=True):
  """Build NASNet model for the Cifar Dataset."""
  hparams = _cifar_config(is_training=is_training)

  if tf.test.is_gpu_available() and hparams.data_format == 'NHWC':
    tf.logging.info('A GPU is available on the machine, consider using NCHW '
                    'data format for increased speed on GPU.')

  if hparams.data_format == 'NCHW':
    images = tf.transpose(images, [0, 3, 1, 2])

  # Calculate the total number of cells in the network
  # Add 2 for the reduction cells
  total_num_cells = hparams.num_cells + 2

  normal_cell = nasnet_utils.NasNetANormalCell(
      hparams.num_conv_filters, hparams.drop_path_keep_prob,
      total_num_cells, hparams.total_training_steps)
  reduction_cell = nasnet_utils.NasNetAReductionCell(
      hparams.num_conv_filters, hparams.drop_path_keep_prob,
      total_num_cells, hparams.total_training_steps)
  with arg_scope([slim.dropout, nasnet_utils.drop_path, slim.batch_norm],
                 is_training=is_training):
    with arg_scope([slim.avg_pool2d,
                    slim.max_pool2d,
                    slim.conv2d,
                    slim.batch_norm,
                    slim.separable_conv2d,
                    nasnet_utils.factorized_reduction,
                    nasnet_utils.global_avg_pool,
                    nasnet_utils.get_channel_index,
                    nasnet_utils.get_channel_dim],
                   data_format=hparams.data_format):
      return _build_nasnet_base(images,
                                normal_cell=normal_cell,
                                reduction_cell=reduction_cell,
                                num_classes=num_classes,
                                hparams=hparams,
                                is_training=is_training,
                                stem_type='cifar')
build_nasnet_cifar.default_image_size = 32


def build_nasnet_mobile(images, num_classes,
                        is_training=True,
                        final_endpoint=None):
  """Build NASNet Mobile model for the ImageNet Dataset."""
  hparams = _mobile_imagenet_config()

  if tf.test.is_gpu_available() and hparams.data_format == 'NHWC':
    tf.logging.info('A GPU is available on the machine, consider using NCHW '
                    'data format for increased speed on GPU.')

  if hparams.data_format == 'NCHW':
    images = tf.transpose(images, [0, 3, 1, 2])

  # Calculate the total number of cells in the network
  # Add 2 for the reduction cells
  total_num_cells = hparams.num_cells + 2
  # If ImageNet, then add an additional two for the stem cells
  total_num_cells += 2

  normal_cell = nasnet_utils.NasNetANormalCell(
      hparams.num_conv_filters, hparams.drop_path_keep_prob,
      total_num_cells, hparams.total_training_steps)
  reduction_cell = nasnet_utils.NasNetAReductionCell(
      hparams.num_conv_filters, hparams.drop_path_keep_prob,
      total_num_cells, hparams.total_training_steps)
  with arg_scope([slim.dropout, nasnet_utils.drop_path, slim.batch_norm],
                 is_training=is_training):
    with arg_scope([slim.avg_pool2d,
                    slim.max_pool2d,
                    slim.conv2d,
                    slim.batch_norm,
                    slim.separable_conv2d,
                    nasnet_utils.factorized_reduction,
                    nasnet_utils.global_avg_pool,
                    nasnet_utils.get_channel_index,
                    nasnet_utils.get_channel_dim],
                   data_format=hparams.data_format):
      return _build_nasnet_base(images,
                                normal_cell=normal_cell,
                                reduction_cell=reduction_cell,
                                num_classes=num_classes,
                                hparams=hparams,
                                is_training=is_training,
                                stem_type='imagenet',
                                final_endpoint=final_endpoint)
build_nasnet_mobile.default_image_size = 224


def build_nasnet_large(images, num_classes,
                       is_training=True,
                       final_endpoint=None):
  """Build NASNet Large model for the ImageNet Dataset."""
  hparams = _large_imagenet_config(is_training=is_training)

  if tf.test.is_gpu_available() and hparams.data_format == 'NHWC':
    tf.logging.info('A GPU is available on the machine, consider using NCHW '
                    'data format for increased speed on GPU.')

  if hparams.data_format == 'NCHW':
    images = tf.transpose(images, [0, 3, 1, 2])

  # Calculate the total number of cells in the network
  # Add 2 for the reduction cells
  total_num_cells = hparams.num_cells + 2
  # If ImageNet, then add an additional two for the stem cells
  total_num_cells += 2

  normal_cell = nasnet_utils.NasNetANormalCell(
      hparams.num_conv_filters, hparams.drop_path_keep_prob,
      total_num_cells, hparams.total_training_steps)
  reduction_cell = nasnet_utils.NasNetAReductionCell(
      hparams.num_conv_filters, hparams.drop_path_keep_prob,
      total_num_cells, hparams.total_training_steps)
  with arg_scope([slim.dropout, nasnet_utils.drop_path, slim.batch_norm],
                 is_training=is_training):
    with arg_scope([slim.avg_pool2d,
                    slim.max_pool2d,
                    slim.conv2d,
                    slim.batch_norm,
                    slim.separable_conv2d,
                    nasnet_utils.factorized_reduction,
                    nasnet_utils.global_avg_pool,
                    nasnet_utils.get_channel_index,
                    nasnet_utils.get_channel_dim],
                   data_format=hparams.data_format):
      return _build_nasnet_base(images,
                                normal_cell=normal_cell,
                                reduction_cell=reduction_cell,
                                num_classes=num_classes,
                                hparams=hparams,
                                is_training=is_training,
                                stem_type='imagenet',
                                final_endpoint=final_endpoint)
build_nasnet_large.default_image_size = 331


def _build_nasnet_base(images,
                       normal_cell,
                       reduction_cell,
                       num_classes,
                       hparams,
                       is_training,
                       stem_type,
                       final_endpoint=None):
  """Constructs a NASNet image model."""

  end_points = {}
  def add_and_check_endpoint(endpoint_name, net):
    end_points[endpoint_name] = net
    return final_endpoint and (endpoint_name == final_endpoint)

  # Find where to place the reduction cells or stride normal cells
  reduction_indices = nasnet_utils.calc_reduction_layers(
      hparams.num_cells, hparams.num_reduction_layers)
  stem_cell = reduction_cell

  if stem_type == 'imagenet':
    stem = lambda: _imagenet_stem(images, hparams, stem_cell)
  elif stem_type == 'cifar':
    stem = lambda: _cifar_stem(images, hparams)
  else:
    raise ValueError('Unknown stem_type: {}'.format(stem_type))
  net, cell_outputs = stem()
  if add_and_check_endpoint('Stem', net):
    return net, end_points

  # Setup for building in the auxiliary head.
  aux_head_cell_idxes = []
  if len(reduction_indices) >= 2:
    aux_head_cell_idxes.append(reduction_indices[1] - 1)

  # Run the cells
  filter_scaling = 1.0
  # true_cell_num accounts for the stem cells
  true_cell_num = 2 if stem_type == 'imagenet' else 0
  for cell_num in range(hparams.num_cells):
    stride = 1
    if hparams.skip_reduction_layer_input:
      prev_layer = cell_outputs[-2]
    if cell_num in reduction_indices:
      filter_scaling *= hparams.filter_scaling_rate
      net = reduction_cell(
          net,
          scope='reduction_cell_{}'.format(reduction_indices.index(cell_num)),
          filter_scaling=filter_scaling,
          stride=2,
          prev_layer=cell_outputs[-2],
          cell_num=true_cell_num)
      if add_and_check_endpoint(
          'Reduction_Cell_{}'.format(reduction_indices.index(cell_num)), net):
        return net, end_points
      true_cell_num += 1
      cell_outputs.append(net)
    if not hparams.skip_reduction_layer_input:
      prev_layer = cell_outputs[-2]
    net = normal_cell(
        net,
        scope='cell_{}'.format(cell_num),
        filter_scaling=filter_scaling,
        stride=stride,
        prev_layer=prev_layer,
        cell_num=true_cell_num)

    if add_and_check_endpoint('Cell_{}'.format(cell_num), net):
      return net, end_points
    true_cell_num += 1
    if (hparams.use_aux_head and cell_num in aux_head_cell_idxes and
        num_classes and is_training):
      aux_net = tf.nn.relu(net)
      _build_aux_head(aux_net, end_points, num_classes, hparams,
                      scope='aux_{}'.format(cell_num))
    cell_outputs.append(net)

  # Final softmax layer
  with tf.variable_scope('final_layer'):
    net = tf.nn.relu(net)
    net = nasnet_utils.global_avg_pool(net)
    if add_and_check_endpoint('global_pool', net) or num_classes is None:
      return net, end_points
    net = slim.dropout(net, hparams.dense_dropout_keep_prob, scope='dropout')
    logits = slim.fully_connected(net, num_classes)

    if add_and_check_endpoint('Logits', logits):
      return net, end_points

    predictions = tf.nn.softmax(logits, name='predictions')
    if add_and_check_endpoint('Predictions', predictions):
      return net, end_points
  return logits, end_points
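

# Usage sketch (not part of the original model code): a minimal example of
# building the mobile model for inference, assuming the TF 1.x contrib stack
# imported above. num_classes=1001 follows the slim ImageNet convention
# (1000 classes plus background) and 224 matches
# build_nasnet_mobile.default_image_size; both are illustrative choices here.
#
#   images = tf.placeholder(tf.float32, [None, 224, 224, 3])
#   with slim.arg_scope(nasnet_mobile_arg_scope()):
#     logits, end_points = build_nasnet_mobile(
#         images, num_classes=1001, is_training=False)
#   predictions = end_points['Predictions']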