# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains the definition for the NASNet classification networks.

Paper: https://arxiv.org/abs/1707.07012
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from . import nasnet_utils

arg_scope = tf.contrib.framework.arg_scope
slim = tf.contrib.slim


# Notes for training NASNet Cifar Model
# -------------------------------------
# batch_size: 32
# learning rate: 0.025
# cosine (single period) learning rate decay
# auxiliary head loss weighting: 0.4
# clip global norm of all gradients by 5
def _cifar_config(is_training=True):
  drop_path_keep_prob = 1.0 if not is_training else 0.6
  return tf.contrib.training.HParams(
      stem_multiplier=3.0,
      drop_path_keep_prob=drop_path_keep_prob,
      num_cells=18,
      use_aux_head=1,
      num_conv_filters=32,
      dense_dropout_keep_prob=1.0,
      filter_scaling_rate=2.0,
      num_reduction_layers=2,
      data_format='NHWC',
      skip_reduction_layer_input=0,
      # 600 epochs with a batch size of 32
      # This is used for the drop path probabilities since the drop probability
      # needs to increase over the course of training.
      total_training_steps=937500,
  )
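# The notes above describe a single-period cosine learning rate decay over
# `total_training_steps`. A minimal sketch of that schedule is shown below;
# it assumes a TensorFlow version that provides `tf.train.cosine_decay` and is
# illustrative only, not part of the model definition:
#
#   global_step = tf.train.get_or_create_global_step()
#   learning_rate = tf.train.cosine_decay(
#       learning_rate=0.025,   # initial learning rate from the notes above
#       global_step=global_step,
#       decay_steps=937500)    # total_training_steps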


# Notes for training large NASNet model on ImageNet
# -------------------------------------
# batch size (per replica): 16
# learning rate: 0.015 * 100
# learning rate decay factor: 0.97
# num epochs per decay: 2.4
# sync sgd with 100 replicas
# auxiliary head loss weighting: 0.4
# label smoothing: 0.1
# clip global norm of all gradients by 10
def _large_imagenet_config(is_training=True):
  drop_path_keep_prob = 1.0 if not is_training else 0.7
  return tf.contrib.training.HParams(
      stem_multiplier=3.0,
      dense_dropout_keep_prob=0.5,
      num_cells=18,
      filter_scaling_rate=2.0,
      num_conv_filters=168,
      drop_path_keep_prob=drop_path_keep_prob,
      use_aux_head=1,
      num_reduction_layers=2,
      data_format='NHWC',
      skip_reduction_layer_input=1,
      total_training_steps=250000,
  )
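# A minimal sketch (illustrative only, not part of the model definition) of
# the exponentially decayed learning rate and synchronous SGD setup described
# in the notes above; `steps_per_epoch` and `base_optimizer` are placeholders:
#
#   global_step = tf.train.get_or_create_global_step()
#   learning_rate = tf.train.exponential_decay(
#       learning_rate=0.015 * 100,
#       global_step=global_step,
#       decay_steps=int(steps_per_epoch * 2.4),  # 2.4 epochs per decay
#       decay_rate=0.97)
#   optimizer = tf.train.SyncReplicasOptimizer(
#       base_optimizer,                          # placeholder optimizer
#       replicas_to_aggregate=100,
#       total_num_replicas=100)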


# Notes for training the mobile NASNet ImageNet model
# -------------------------------------
# batch size (per replica): 32
# learning rate: 0.04 * 50
# learning rate decay factor: 0.97
# num epochs per decay: 2.4
# sync sgd with 50 replicas
# auxiliary head weighting: 0.4
# label smoothing: 0.1
# clip global norm of all gradients by 10
def _mobile_imagenet_config():
  return tf.contrib.training.HParams(
      stem_multiplier=1.0,
      dense_dropout_keep_prob=0.5,
      num_cells=12,
      filter_scaling_rate=2.0,
      drop_path_keep_prob=1.0,
      num_conv_filters=44,
      use_aux_head=1,
      num_reduction_layers=2,
      data_format='NHWC',
      skip_reduction_layer_input=0,
      total_training_steps=250000,
  )


def nasnet_cifar_arg_scope(weight_decay=5e-4,
                           batch_norm_decay=0.9,
                           batch_norm_epsilon=1e-5):
  """Defines the default arg scope for the NASNet-A Cifar model.

  Args:
    weight_decay: The weight decay to use for regularizing the model.
    batch_norm_decay: Decay for batch norm moving average.
    batch_norm_epsilon: Small float added to variance to avoid dividing by zero
      in batch norm.

  Returns:
    An `arg_scope` to use for the NASNet Cifar Model.
  """
  batch_norm_params = {
      # Decay for the moving averages.
      'decay': batch_norm_decay,
      # epsilon to prevent 0s in variance.
      'epsilon': batch_norm_epsilon,
      'scale': True,
      'fused': True,
  }
  weights_regularizer = tf.contrib.layers.l2_regularizer(weight_decay)
  weights_initializer = tf.contrib.layers.variance_scaling_initializer(
      mode='FAN_OUT')
  with arg_scope([slim.fully_connected, slim.conv2d, slim.separable_conv2d],
                 weights_regularizer=weights_regularizer,
                 weights_initializer=weights_initializer):
    with arg_scope([slim.fully_connected],
                   activation_fn=None, scope='FC'):
      with arg_scope([slim.conv2d, slim.separable_conv2d],
                     activation_fn=None, biases_initializer=None):
        with arg_scope([slim.batch_norm], **batch_norm_params) as sc:
          return sc


def nasnet_mobile_arg_scope(weight_decay=4e-5,
                            batch_norm_decay=0.9997,
                            batch_norm_epsilon=1e-3):
  """Defines the default arg scope for the NASNet-A Mobile ImageNet model.

  Args:
    weight_decay: The weight decay to use for regularizing the model.
    batch_norm_decay: Decay for batch norm moving average.
    batch_norm_epsilon: Small float added to variance to avoid dividing by zero
      in batch norm.

  Returns:
    An `arg_scope` to use for the NASNet Mobile Model.
  """
  batch_norm_params = {
      # Decay for the moving averages.
      'decay': batch_norm_decay,
      # epsilon to prevent 0s in variance.
      'epsilon': batch_norm_epsilon,
      'scale': True,
      'fused': True,
  }
  weights_regularizer = tf.contrib.layers.l2_regularizer(weight_decay)
  weights_initializer = tf.contrib.layers.variance_scaling_initializer(
      mode='FAN_OUT')
  with arg_scope([slim.fully_connected, slim.conv2d, slim.separable_conv2d],
                 weights_regularizer=weights_regularizer,
                 weights_initializer=weights_initializer):
    with arg_scope([slim.fully_connected],
                   activation_fn=None, scope='FC'):
      with arg_scope([slim.conv2d, slim.separable_conv2d],
                     activation_fn=None, biases_initializer=None):
        with arg_scope([slim.batch_norm], **batch_norm_params) as sc:
          return sc


def nasnet_large_arg_scope(weight_decay=5e-5,
                           batch_norm_decay=0.9997,
                           batch_norm_epsilon=1e-3):
  """Defines the default arg scope for the NASNet-A Large ImageNet model.

  Args:
    weight_decay: The weight decay to use for regularizing the model.
    batch_norm_decay: Decay for batch norm moving average.
    batch_norm_epsilon: Small float added to variance to avoid dividing by zero
      in batch norm.

  Returns:
    An `arg_scope` to use for the NASNet Large Model.
  """
  batch_norm_params = {
      # Decay for the moving averages.
      'decay': batch_norm_decay,
      # epsilon to prevent 0s in variance.
      'epsilon': batch_norm_epsilon,
      'scale': True,
      'fused': True,
  }
  weights_regularizer = tf.contrib.layers.l2_regularizer(weight_decay)
  weights_initializer = tf.contrib.layers.variance_scaling_initializer(
      mode='FAN_OUT')
  with arg_scope([slim.fully_connected, slim.conv2d, slim.separable_conv2d],
                 weights_regularizer=weights_regularizer,
                 weights_initializer=weights_initializer):
    with arg_scope([slim.fully_connected],
                   activation_fn=None, scope='FC'):
      with arg_scope([slim.conv2d, slim.separable_conv2d],
                     activation_fn=None, biases_initializer=None):
        with arg_scope([slim.batch_norm], **batch_norm_params) as sc:
          return sc


def _build_aux_head(net, end_points, num_classes, hparams, scope):
  """Auxiliary head used for all models across all datasets."""
  with tf.variable_scope(scope):
    aux_logits = tf.identity(net)
    with tf.variable_scope('aux_logits'):
      aux_logits = slim.avg_pool2d(
          aux_logits, [5, 5], stride=3, padding='VALID')
      aux_logits = slim.conv2d(aux_logits, 128, [1, 1], scope='proj')
      aux_logits = slim.batch_norm(aux_logits, scope='aux_bn0')
      aux_logits = tf.nn.relu(aux_logits)
      # Shape of feature map before the final layer.
      shape = aux_logits.shape
      if hparams.data_format == 'NHWC':
        shape = shape[1:3]
      else:
        shape = shape[2:4]
      aux_logits = slim.conv2d(aux_logits, 768, shape, padding='VALID')
      aux_logits = slim.batch_norm(aux_logits, scope='aux_bn1')
      aux_logits = tf.nn.relu(aux_logits)
      aux_logits = tf.contrib.layers.flatten(aux_logits)
      aux_logits = slim.fully_connected(aux_logits, num_classes)
      end_points['AuxLogits'] = aux_logits


def _imagenet_stem(inputs, hparams, stem_cell):
  """Stem used for models trained on ImageNet."""
  num_stem_cells = 2

  # 149 x 149 x 32
  num_stem_filters = int(32 * hparams.stem_multiplier)
  net = slim.conv2d(
      inputs, num_stem_filters, [3, 3], stride=2, scope='conv0',
      padding='VALID')
  net = slim.batch_norm(net, scope='conv0_bn')

  # Run the reduction cells
  cell_outputs = [None, net]
  filter_scaling = 1.0 / (hparams.filter_scaling_rate**num_stem_cells)
  for cell_num in range(num_stem_cells):
    net = stem_cell(
        net,
        scope='cell_stem_{}'.format(cell_num),
        filter_scaling=filter_scaling,
        stride=2,
        prev_layer=cell_outputs[-2],
        cell_num=cell_num)
    cell_outputs.append(net)
    filter_scaling *= hparams.filter_scaling_rate
  return net, cell_outputs


def _cifar_stem(inputs, hparams):
  """Stem used for models trained on Cifar."""
  num_stem_filters = int(hparams.num_conv_filters * hparams.stem_multiplier)
  net = slim.conv2d(
      inputs,
      num_stem_filters,
      3,
      scope='l1_stem_3x3')
  net = slim.batch_norm(net, scope='l1_stem_bn')
  return net, [None, net]


def build_nasnet_cifar(
    images, num_classes, is_training=True):
  """Build NASNet model for the Cifar Dataset.

  Args:
    images: Input batch of images.
    num_classes: Number of classes to predict. If None, the logits layer is
      omitted and the globally pooled activations are returned instead.
    is_training: Whether the model is being trained.

  Returns:
    logits: Output logits (or pooled activations if `num_classes` is None).
    end_points: Dictionary mapping endpoint names to activations.
  """
  hparams = _cifar_config(is_training=is_training)

  if tf.test.is_gpu_available() and hparams.data_format == 'NHWC':
    tf.logging.info('A GPU is available on the machine, consider using NCHW '
                    'data format for increased speed on GPU.')

  if hparams.data_format == 'NCHW':
    images = tf.transpose(images, [0, 3, 1, 2])

  # Calculate the total number of cells in the network
  # Add 2 for the reduction cells
  total_num_cells = hparams.num_cells + 2

  normal_cell = nasnet_utils.NasNetANormalCell(
      hparams.num_conv_filters, hparams.drop_path_keep_prob,
      total_num_cells, hparams.total_training_steps)
  reduction_cell = nasnet_utils.NasNetAReductionCell(
      hparams.num_conv_filters, hparams.drop_path_keep_prob,
      total_num_cells, hparams.total_training_steps)
  with arg_scope([slim.dropout, nasnet_utils.drop_path, slim.batch_norm],
                 is_training=is_training):
    with arg_scope([slim.avg_pool2d,
                    slim.max_pool2d,
                    slim.conv2d,
                    slim.batch_norm,
                    slim.separable_conv2d,
                    nasnet_utils.factorized_reduction,
                    nasnet_utils.global_avg_pool,
                    nasnet_utils.get_channel_index,
                    nasnet_utils.get_channel_dim],
                   data_format=hparams.data_format):
      return _build_nasnet_base(images,
                                normal_cell=normal_cell,
                                reduction_cell=reduction_cell,
                                num_classes=num_classes,
                                hparams=hparams,
                                is_training=is_training,
                                stem_type='cifar')
build_nasnet_cifar.default_image_size = 32
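# Example usage (a minimal sketch; the placeholder input tensor and the
# CIFAR-10 class count are illustrative only):
#
#   images = tf.placeholder(tf.float32, [None, 32, 32, 3])
#   with slim.arg_scope(nasnet_cifar_arg_scope()):
#     logits, end_points = build_nasnet_cifar(
#         images, num_classes=10, is_training=False)
#   predictions = end_points['Predictions']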


def build_nasnet_mobile(images, num_classes,
                        is_training=True,
                        final_endpoint=None):
  """Build NASNet Mobile model for the ImageNet Dataset.

  Args:
    images: Input batch of images.
    num_classes: Number of classes to predict. If None, the logits layer is
      omitted and the globally pooled activations are returned instead.
    is_training: Whether the model is being trained.
    final_endpoint: Optional endpoint name (e.g. 'Stem', 'Cell_5',
      'global_pool') at which to stop building the network and return early.

  Returns:
    logits: Output logits, or the activations at `final_endpoint` if set.
    end_points: Dictionary mapping endpoint names to activations.
  """
  hparams = _mobile_imagenet_config()

  if tf.test.is_gpu_available() and hparams.data_format == 'NHWC':
    tf.logging.info('A GPU is available on the machine, consider using NCHW '
                    'data format for increased speed on GPU.')

  if hparams.data_format == 'NCHW':
    images = tf.transpose(images, [0, 3, 1, 2])

  # Calculate the total number of cells in the network
  # Add 2 for the reduction cells
  total_num_cells = hparams.num_cells + 2
  # If ImageNet, then add an additional two for the stem cells
  total_num_cells += 2

  normal_cell = nasnet_utils.NasNetANormalCell(
      hparams.num_conv_filters, hparams.drop_path_keep_prob,
      total_num_cells, hparams.total_training_steps)
  reduction_cell = nasnet_utils.NasNetAReductionCell(
      hparams.num_conv_filters, hparams.drop_path_keep_prob,
      total_num_cells, hparams.total_training_steps)
  with arg_scope([slim.dropout, nasnet_utils.drop_path, slim.batch_norm],
                 is_training=is_training):
    with arg_scope([slim.avg_pool2d,
                    slim.max_pool2d,
                    slim.conv2d,
                    slim.batch_norm,
                    slim.separable_conv2d,
                    nasnet_utils.factorized_reduction,
                    nasnet_utils.global_avg_pool,
                    nasnet_utils.get_channel_index,
                    nasnet_utils.get_channel_dim],
                   data_format=hparams.data_format):
      return _build_nasnet_base(images,
                                normal_cell=normal_cell,
                                reduction_cell=reduction_cell,
                                num_classes=num_classes,
                                hparams=hparams,
                                is_training=is_training,
                                stem_type='imagenet',
                                final_endpoint=final_endpoint)
build_nasnet_mobile.default_image_size = 224


def build_nasnet_large(images, num_classes,
                       is_training=True,
                       final_endpoint=None):
  """Build NASNet Large model for the ImageNet Dataset.

  Args:
    images: Input batch of images.
    num_classes: Number of classes to predict. If None, the logits layer is
      omitted and the globally pooled activations are returned instead.
    is_training: Whether the model is being trained.
    final_endpoint: Optional endpoint name (e.g. 'Stem', 'Cell_5',
      'global_pool') at which to stop building the network and return early.

  Returns:
    logits: Output logits, or the activations at `final_endpoint` if set.
    end_points: Dictionary mapping endpoint names to activations.
  """
  hparams = _large_imagenet_config(is_training=is_training)

  if tf.test.is_gpu_available() and hparams.data_format == 'NHWC':
    tf.logging.info('A GPU is available on the machine, consider using NCHW '
                    'data format for increased speed on GPU.')

  if hparams.data_format == 'NCHW':
    images = tf.transpose(images, [0, 3, 1, 2])

  # Calculate the total number of cells in the network
  # Add 2 for the reduction cells
  total_num_cells = hparams.num_cells + 2
  # If ImageNet, then add an additional two for the stem cells
  total_num_cells += 2

  normal_cell = nasnet_utils.NasNetANormalCell(
      hparams.num_conv_filters, hparams.drop_path_keep_prob,
      total_num_cells, hparams.total_training_steps)
  reduction_cell = nasnet_utils.NasNetAReductionCell(
      hparams.num_conv_filters, hparams.drop_path_keep_prob,
      total_num_cells, hparams.total_training_steps)
  with arg_scope([slim.dropout, nasnet_utils.drop_path, slim.batch_norm],
                 is_training=is_training):
    with arg_scope([slim.avg_pool2d,
                    slim.max_pool2d,
                    slim.conv2d,
                    slim.batch_norm,
                    slim.separable_conv2d,
                    nasnet_utils.factorized_reduction,
                    nasnet_utils.global_avg_pool,
                    nasnet_utils.get_channel_index,
                    nasnet_utils.get_channel_dim],
                   data_format=hparams.data_format):
      return _build_nasnet_base(images,
                                normal_cell=normal_cell,
                                reduction_cell=reduction_cell,
                                num_classes=num_classes,
                                hparams=hparams,
                                is_training=is_training,
                                stem_type='imagenet',
                                final_endpoint=final_endpoint)
build_nasnet_large.default_image_size = 331
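# Example of stopping the network early via `final_endpoint` to obtain the
# globally pooled features instead of logits (a minimal sketch; the placeholder
# input shape and class count are illustrative only):
#
#   images = tf.placeholder(tf.float32, [None, 331, 331, 3])
#   with slim.arg_scope(nasnet_large_arg_scope()):
#     net, end_points = build_nasnet_large(
#         images, num_classes=1001, is_training=False,
#         final_endpoint='global_pool')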


def _build_nasnet_base(images,
                       normal_cell,
                       reduction_cell,
                       num_classes,
                       hparams,
                       is_training,
                       stem_type,
                       final_endpoint=None):
  """Constructs a NASNet image model."""

  end_points = {}
  def add_and_check_endpoint(endpoint_name, net):
    end_points[endpoint_name] = net
    return final_endpoint and (endpoint_name == final_endpoint)

  # Find where to place the reduction cells or stride normal cells
  reduction_indices = nasnet_utils.calc_reduction_layers(
      hparams.num_cells, hparams.num_reduction_layers)
  stem_cell = reduction_cell

  if stem_type == 'imagenet':
    stem = lambda: _imagenet_stem(images, hparams, stem_cell)
  elif stem_type == 'cifar':
    stem = lambda: _cifar_stem(images, hparams)
  else:
    raise ValueError('Unknown stem_type: {}'.format(stem_type))
  net, cell_outputs = stem()
  if add_and_check_endpoint('Stem', net): return net, end_points

  # Setup for building in the auxiliary head.
  aux_head_cell_idxes = []
  if len(reduction_indices) >= 2:
    aux_head_cell_idxes.append(reduction_indices[1] - 1)

  # Run the cells
  filter_scaling = 1.0
  # true_cell_num accounts for the stem cells
  true_cell_num = 2 if stem_type == 'imagenet' else 0
  for cell_num in range(hparams.num_cells):
    stride = 1
    if hparams.skip_reduction_layer_input:
      prev_layer = cell_outputs[-2]
    if cell_num in reduction_indices:
      filter_scaling *= hparams.filter_scaling_rate
      net = reduction_cell(
          net,
          scope='reduction_cell_{}'.format(reduction_indices.index(cell_num)),
          filter_scaling=filter_scaling,
          stride=2,
          prev_layer=cell_outputs[-2],
          cell_num=true_cell_num)
      if add_and_check_endpoint(
          'Reduction_Cell_{}'.format(reduction_indices.index(cell_num)), net):
        return net, end_points
      true_cell_num += 1
      cell_outputs.append(net)
    if not hparams.skip_reduction_layer_input:
      prev_layer = cell_outputs[-2]
    net = normal_cell(
        net,
        scope='cell_{}'.format(cell_num),
        filter_scaling=filter_scaling,
        stride=stride,
        prev_layer=prev_layer,
        cell_num=true_cell_num)

    if add_and_check_endpoint('Cell_{}'.format(cell_num), net):
      return net, end_points
    true_cell_num += 1
    if (hparams.use_aux_head and cell_num in aux_head_cell_idxes and
        num_classes and is_training):
      aux_net = tf.nn.relu(net)
      _build_aux_head(aux_net, end_points, num_classes, hparams,
                      scope='aux_{}'.format(cell_num))
    cell_outputs.append(net)

  # Final softmax layer
  with tf.variable_scope('final_layer'):
    net = tf.nn.relu(net)
    net = nasnet_utils.global_avg_pool(net)
    if add_and_check_endpoint('global_pool', net) or num_classes is None:
      return net, end_points
    net = slim.dropout(net, hparams.dense_dropout_keep_prob, scope='dropout')
    logits = slim.fully_connected(net, num_classes)

    if add_and_check_endpoint('Logits', logits):
      return net, end_points

    predictions = tf.nn.softmax(logits, name='predictions')
    if add_and_check_endpoint('Predictions', predictions):
      return net, end_points
  return logits, end_points