# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A custom module for some common operations used by NASNet.

Functions exposed in this file:
- calc_reduction_layers
- get_channel_index
- get_channel_dim
- global_avg_pool
- factorized_reduction
- drop_path

Classes exposed in this file:
- NasNetABaseCell
- NasNetANormalCell
- NasNetAReductionCell
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf


arg_scope = tf.contrib.framework.arg_scope
slim = tf.contrib.slim

DATA_FORMAT_NCHW = 'NCHW'
DATA_FORMAT_NHWC = 'NHWC'
INVALID = 'null'


def calc_reduction_layers(num_cells, num_reduction_layers):
  """Figure out what layers should have reductions."""
  reduction_layers = []
  for pool_num in range(1, num_reduction_layers + 1):
    layer_num = (float(pool_num) / (num_reduction_layers + 1)) * num_cells
    layer_num = int(layer_num)
    reduction_layers.append(layer_num)
  return reduction_layers
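

# The reduction layers are spaced evenly through the network. Illustrative
# (non-executed) example: with 18 cells and 2 reduction layers, reductions
# land at the one-third and two-thirds points:
#
#   calc_reduction_layers(18, 2)  # -> [6, 12]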


@tf.contrib.framework.add_arg_scope
def get_channel_index(data_format=INVALID):
  assert data_format != INVALID
  axis = 3 if data_format == 'NHWC' else 1
  return axis


@tf.contrib.framework.add_arg_scope
def get_channel_dim(shape, data_format=INVALID):
  assert data_format != INVALID
  assert len(shape) == 4
  if data_format == 'NHWC':
    return int(shape[3])
  elif data_format == 'NCHW':
    return int(shape[1])
  else:
    raise ValueError('Not a valid data_format', data_format)


@tf.contrib.framework.add_arg_scope
def global_avg_pool(x, data_format=INVALID):
  """Average pool away the height and width spatial dimensions of x."""
  assert data_format != INVALID
  assert data_format in ['NHWC', 'NCHW']
  assert x.shape.ndims == 4
  if data_format == 'NHWC':
    return tf.reduce_mean(x, [1, 2])
  else:
    return tf.reduce_mean(x, [2, 3])
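

# Shape note (illustrative): for an NHWC input of shape
# [batch, height, width, channels], global_avg_pool returns [batch, channels];
# for NCHW inputs ([batch, channels, height, width]) it likewise reduces to
# [batch, channels].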


@tf.contrib.framework.add_arg_scope
def factorized_reduction(net, output_filters, stride, data_format=INVALID):
  """Reduces the shape of net without information loss due to striding."""
  assert output_filters % 2 == 0, (
      'Need even number of filters when using this factorized reduction.')
  assert data_format != INVALID
  if stride == 1:
    net = slim.conv2d(net, output_filters, 1, scope='path_conv')
    net = slim.batch_norm(net, scope='path_bn')
    return net
  if data_format == 'NHWC':
    stride_spec = [1, stride, stride, 1]
  else:
    stride_spec = [1, 1, stride, stride]

  # Skip path 1
  path1 = tf.nn.avg_pool(
      net, [1, 1, 1, 1], stride_spec, 'VALID', data_format=data_format)
  path1 = slim.conv2d(path1, int(output_filters / 2), 1, scope='path1_conv')

  # Skip path 2
  # First pad with 0's on the right and bottom, then shift the filter to
  # include those 0's that were added.
  if data_format == 'NHWC':
    pad_arr = [[0, 0], [0, 1], [0, 1], [0, 0]]
    path2 = tf.pad(net, pad_arr)[:, 1:, 1:, :]
    concat_axis = 3
  else:
    pad_arr = [[0, 0], [0, 0], [0, 1], [0, 1]]
    path2 = tf.pad(net, pad_arr)[:, :, 1:, 1:]
    concat_axis = 1

  path2 = tf.nn.avg_pool(
      path2, [1, 1, 1, 1], stride_spec, 'VALID', data_format=data_format)
  path2 = slim.conv2d(path2, int(output_filters / 2), 1, scope='path2_conv')

  # Concat and apply BN
  final_path = tf.concat(values=[path1, path2], axis=concat_axis)
  final_path = slim.batch_norm(final_path, scope='final_path_bn')
  return final_path
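

# Shape sketch (illustrative, assuming NHWC): with net of shape
# [8, 32, 32, 64], output_filters=128 and stride=2, each path produces
# [8, 16, 16, 64] (path 2 samples a grid offset by one pixel, thanks to the
# pad-and-shift above), and the concatenated, batch-normalized result has
# shape [8, 16, 16, 128].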


@tf.contrib.framework.add_arg_scope
def drop_path(net, keep_prob, is_training=True):
  """Drops out a whole example hiddenstate with the specified probability."""
  if is_training:
    batch_size = tf.shape(net)[0]
    noise_shape = [batch_size, 1, 1, 1]
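    # random_tensor is uniform in [keep_prob, 1 + keep_prob), so flooring it
    # yields 1 with probability keep_prob and 0 otherwise: a per-example
    # Bernoulli mask. Dividing net by keep_prob keeps its expected value
    # unchanged.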
    random_tensor = keep_prob
    random_tensor += tf.random_uniform(noise_shape, dtype=tf.float32)
    binary_tensor = tf.floor(random_tensor)
    net = tf.div(net, keep_prob) * binary_tensor
  return net


def _operation_to_filter_shape(operation):
  splitted_operation = operation.split('x')
  filter_shape = int(splitted_operation[0][-1])
  assert filter_shape == int(
      splitted_operation[1][0]), 'Rectangular filters not supported.'
  return filter_shape


def _operation_to_num_layers(operation):
  splitted_operation = operation.split('_')
  if 'x' in splitted_operation[-1]:
    return 1
  return int(splitted_operation[-1])


def _operation_to_info(operation):
  """Takes in operation name and returns meta information.

  An example would be 'separable_3x3_4' -> (4, 3).

  Args:
    operation: String that corresponds to convolution operation.

  Returns:
    Tuple of (num layers, filter shape).
  """
  num_layers = _operation_to_num_layers(operation)
  filter_shape = _operation_to_filter_shape(operation)
  return num_layers, filter_shape


def _stacked_separable_conv(net, stride, operation, filter_size):
  """Takes in an operation and parses it into the correct sep operation."""
  num_layers, kernel_size = _operation_to_info(operation)
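  # Stack num_layers blocks of ReLU -> separable conv -> batch norm. Only the
  # first block applies the requested stride; the remaining blocks use
  # stride 1.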
  for layer_num in range(num_layers - 1):
    net = tf.nn.relu(net)
    net = slim.separable_conv2d(
        net,
        filter_size,
        kernel_size,
        depth_multiplier=1,
        scope='separable_{0}x{0}_{1}'.format(kernel_size, layer_num + 1),
        stride=stride)
    net = slim.batch_norm(
        net, scope='bn_sep_{0}x{0}_{1}'.format(kernel_size, layer_num + 1))
    stride = 1
  net = tf.nn.relu(net)
  net = slim.separable_conv2d(
      net,
      filter_size,
      kernel_size,
      depth_multiplier=1,
      scope='separable_{0}x{0}_{1}'.format(kernel_size, num_layers),
      stride=stride)
  net = slim.batch_norm(
      net, scope='bn_sep_{0}x{0}_{1}'.format(kernel_size, num_layers))
  return net


def _operation_to_pooling_type(operation):
  """Takes in the operation string and returns the pooling type."""
  splitted_operation = operation.split('_')
  return splitted_operation[0]


def _operation_to_pooling_shape(operation):
  """Takes in the operation string and returns the pooling kernel shape."""
  splitted_operation = operation.split('_')
  shape = splitted_operation[-1]
  assert 'x' in shape
  filter_height, filter_width = shape.split('x')
  assert filter_height == filter_width
  return int(filter_height)


def _operation_to_pooling_info(operation):
  """Parses the pooling operation string to return its type and shape."""
  pooling_type = _operation_to_pooling_type(operation)
  pooling_shape = _operation_to_pooling_shape(operation)
  return pooling_type, pooling_shape
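

# Parsing sketch (illustrative, not executed): the operation string encodes
# the pooling type and a square kernel size, e.g.
#
#   _operation_to_pooling_info('avg_pool_3x3')  # -> ('avg', 3)
#   _operation_to_pooling_info('max_pool_3x3')  # -> ('max', 3)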


def _pooling(net, stride, operation):
  """Parses operation and performs the correct pooling operation on net."""
  padding = 'SAME'
  pooling_type, pooling_shape = _operation_to_pooling_info(operation)
  if pooling_type == 'avg':
    net = slim.avg_pool2d(net, pooling_shape, stride=stride, padding=padding)
  elif pooling_type == 'max':
    net = slim.max_pool2d(net, pooling_shape, stride=stride, padding=padding)
  else:
    raise NotImplementedError('Unimplemented pooling type: ', pooling_type)
  return net


class NasNetABaseCell(object):
  """NASNet Cell class that is used as a 'layer' in image architectures.

  Args:
    num_conv_filters: The number of filters for each convolution operation.
    operations: List of operations that are performed in the NASNet Cell in
      order.
    used_hiddenstates: Binary array that signals if the hiddenstate was used
      within the cell. This is used to determine what outputs of the cell
      should be concatenated together.
    hiddenstate_indices: Determines what hiddenstates should be combined
      together with the specified operations to create the NASNet cell.
    drop_path_keep_prob: Float keep probability used for drop_path
      regularization.
    total_num_cells: Total number of cells in the network; used to scale the
      drop_path keep probability by cell depth.
    total_training_steps: Total number of training steps over which the
      drop_path keep probability is annealed.
  """

  def __init__(self, num_conv_filters, operations, used_hiddenstates,
               hiddenstate_indices, drop_path_keep_prob, total_num_cells,
               total_training_steps):
    self._num_conv_filters = num_conv_filters
    self._operations = operations
    self._used_hiddenstates = used_hiddenstates
    self._hiddenstate_indices = hiddenstate_indices
    self._drop_path_keep_prob = drop_path_keep_prob
    self._total_num_cells = total_num_cells
    self._total_training_steps = total_training_steps

  def _reduce_prev_layer(self, prev_layer, curr_layer):
    """Matches dimension of prev_layer to the curr_layer."""
    # Set the prev layer to the current layer if it is None
    if prev_layer is None:
      return curr_layer
    curr_num_filters = self._filter_size
    prev_num_filters = get_channel_dim(prev_layer.shape)
    curr_filter_shape = int(curr_layer.shape[2])
    prev_filter_shape = int(prev_layer.shape[2])
    if curr_filter_shape != prev_filter_shape:
      prev_layer = tf.nn.relu(prev_layer)
      prev_layer = factorized_reduction(
          prev_layer, curr_num_filters, stride=2)
    elif curr_num_filters != prev_num_filters:
      prev_layer = tf.nn.relu(prev_layer)
      prev_layer = slim.conv2d(
          prev_layer, curr_num_filters, 1, scope='prev_1x1')
      prev_layer = slim.batch_norm(prev_layer, scope='prev_bn')
    return prev_layer

  def _cell_base(self, net, prev_layer):
    """Runs the beginning of the conv cell before the predicted ops are run."""
    num_filters = self._filter_size

    # Make sure the previous layer's shape is compatible with the current one
    prev_layer = self._reduce_prev_layer(prev_layer, net)

    net = tf.nn.relu(net)
    net = slim.conv2d(net, num_filters, 1, scope='1x1')
    net = slim.batch_norm(net, scope='beginning_bn')
    split_axis = get_channel_index()
    net = tf.split(
        axis=split_axis, num_or_size_splits=1, value=net)
    for split in net:
      assert int(split.shape[split_axis]) == int(self._num_conv_filters *
                                                 self._filter_scaling)
    net.append(prev_layer)
    return net

  def __call__(self, net, scope=None, filter_scaling=1, stride=1,
               prev_layer=None, cell_num=-1):
    """Runs the conv cell."""
    self._cell_num = cell_num
    self._filter_scaling = filter_scaling
    self._filter_size = int(self._num_conv_filters * filter_scaling)

    i = 0
    with tf.variable_scope(scope):
      net = self._cell_base(net, prev_layer)
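      # A NASNet-A cell performs five pairwise combinations: each iteration
      # picks two hidden states, applies the two predicted operations, and
      # adds the results to form a new hidden state.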
      for iteration in range(5):
        with tf.variable_scope('comb_iter_{}'.format(iteration)):
          left_hiddenstate_idx, right_hiddenstate_idx = (
              self._hiddenstate_indices[i],
              self._hiddenstate_indices[i + 1])
          original_input_left = left_hiddenstate_idx < 2
          original_input_right = right_hiddenstate_idx < 2
          h1 = net[left_hiddenstate_idx]
          h2 = net[right_hiddenstate_idx]

          operation_left = self._operations[i]
          operation_right = self._operations[i + 1]
          i += 2
          # Apply conv operations
          with tf.variable_scope('left'):
            h1 = self._apply_conv_operation(h1, operation_left,
                                            stride, original_input_left)
          with tf.variable_scope('right'):
            h2 = self._apply_conv_operation(h2, operation_right,
                                            stride, original_input_right)

          # Combine hidden states using 'add'.
          with tf.variable_scope('combine'):
            h = h1 + h2

          # Add hiddenstate to the list of hiddenstates we can choose from
          net.append(h)

      with tf.variable_scope('cell_output'):
        net = self._combine_unused_states(net)

      return net

  def _apply_conv_operation(self, net, operation,
                            stride, is_from_original_input):
    """Applies the predicted conv operation to net."""
    # Don't stride if this input is not one of the original hiddenstates
    if stride > 1 and not is_from_original_input:
      stride = 1
    input_filters = get_channel_dim(net.shape)
    filter_size = self._filter_size
    if 'separable' in operation:
      net = _stacked_separable_conv(net, stride, operation, filter_size)
    elif operation in ['none']:
      # Check if a stride is needed, then use a strided 1x1 here
      if stride > 1 or (input_filters != filter_size):
        net = tf.nn.relu(net)
        net = slim.conv2d(net, filter_size, 1, stride=stride, scope='1x1')
        net = slim.batch_norm(net, scope='bn_1')
    elif 'pool' in operation:
      net = _pooling(net, stride, operation)
      if input_filters != filter_size:
        net = slim.conv2d(net, filter_size, 1, stride=1, scope='1x1')
        net = slim.batch_norm(net, scope='bn_1')
    else:
      raise ValueError('Unimplemented operation', operation)

    if operation != 'none':
      net = self._apply_drop_path(net)
    return net

  def _combine_unused_states(self, net):
    """Concatenate the unused hidden states of the cell."""
    used_hiddenstates = self._used_hiddenstates

    final_height = int(net[-1].shape[2])
    final_num_filters = get_channel_dim(net[-1].shape)
    assert len(used_hiddenstates) == len(net)
    for idx, used_h in enumerate(used_hiddenstates):
      curr_height = int(net[idx].shape[2])
      curr_num_filters = get_channel_dim(net[idx].shape)

      # Determine if a reduction should be applied to make the number of
      # filters match.
      should_reduce = final_num_filters != curr_num_filters
      should_reduce = (final_height != curr_height) or should_reduce
      should_reduce = should_reduce and not used_h
      if should_reduce:
        stride = 2 if final_height != curr_height else 1
        with tf.variable_scope('reduction_{}'.format(idx)):
          net[idx] = factorized_reduction(
              net[idx], final_num_filters, stride)

    states_to_combine = (
        [h for h, is_used in zip(net, used_hiddenstates) if not is_used])

    # Return the concat of all the states
    concat_axis = get_channel_index()
    net = tf.concat(values=states_to_combine, axis=concat_axis)
    return net

  def _apply_drop_path(self, net):
    """Apply drop_path regularization to net."""
    drop_path_keep_prob = self._drop_path_keep_prob
    if drop_path_keep_prob < 1.0:
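      # Illustrative numbers (assumed, not from the source): with
      # keep_prob=0.6 and 18 cells, cell 11 has layer_ratio 12/18, so the
      # layer-adjusted keep prob is 1 - (12/18) * 0.4 ~= 0.733; halfway
      # through training (current_ratio=0.5) this becomes
      # 1 - 0.5 * (1 - 0.733) ~= 0.867.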
      # Scale keep prob by layer number
      assert self._cell_num != -1
      # The added 2 is for the reduction cells
      num_cells = self._total_num_cells
      layer_ratio = (self._cell_num + 1) / float(num_cells)
      with tf.device('/cpu:0'):
        tf.summary.scalar('layer_ratio', layer_ratio)
      drop_path_keep_prob = 1 - layer_ratio * (1 - drop_path_keep_prob)
      # Decrease the keep probability over time
      current_step = tf.cast(tf.train.get_or_create_global_step(),
                             tf.float32)
      drop_path_burn_in_steps = self._total_training_steps
      current_ratio = (
          current_step / drop_path_burn_in_steps)
      current_ratio = tf.minimum(1.0, current_ratio)
      with tf.device('/cpu:0'):
        tf.summary.scalar('current_ratio', current_ratio)
      drop_path_keep_prob = (
          1 - current_ratio * (1 - drop_path_keep_prob))
      with tf.device('/cpu:0'):
        tf.summary.scalar('drop_path_keep_prob', drop_path_keep_prob)
      net = drop_path(net, drop_path_keep_prob)
    return net


class NasNetANormalCell(NasNetABaseCell):
  """NASNetA Normal Cell."""

  def __init__(self, num_conv_filters, drop_path_keep_prob, total_num_cells,
               total_training_steps):
    operations = ['separable_5x5_2',
                  'separable_3x3_2',
                  'separable_5x5_2',
                  'separable_3x3_2',
                  'avg_pool_3x3',
                  'none',
                  'avg_pool_3x3',
                  'avg_pool_3x3',
                  'separable_3x3_2',
                  'none']
    used_hiddenstates = [1, 0, 0, 0, 0, 0, 0]
    hiddenstate_indices = [0, 1, 1, 1, 0, 1, 1, 1, 0, 0]
    super(NasNetANormalCell, self).__init__(num_conv_filters, operations,
                                            used_hiddenstates,
                                            hiddenstate_indices,
                                            drop_path_keep_prob,
                                            total_num_cells,
                                            total_training_steps)


class NasNetAReductionCell(NasNetABaseCell):
  """NASNetA Reduction Cell."""

  def __init__(self, num_conv_filters, drop_path_keep_prob, total_num_cells,
               total_training_steps):
    operations = ['separable_5x5_2',
                  'separable_7x7_2',
                  'max_pool_3x3',
                  'separable_7x7_2',
                  'avg_pool_3x3',
                  'separable_5x5_2',
                  'none',
                  'avg_pool_3x3',
                  'separable_3x3_2',
                  'max_pool_3x3']
    used_hiddenstates = [1, 1, 1, 0, 0, 0, 0]
    hiddenstate_indices = [0, 1, 0, 1, 0, 1, 3, 2, 2, 0]
    super(NasNetAReductionCell, self).__init__(num_conv_filters, operations,
                                               used_hiddenstates,
                                               hiddenstate_indices,
                                               drop_path_keep_prob,
                                               total_num_cells,
                                               total_training_steps)
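

# Minimal usage sketch (hypothetical hyperparameters; `inputs` and `prev`
# stand in for the current and previous cell outputs, which are not defined
# here). The channel helpers read data_format from an arg_scope, matching how
# the NASNet network builders configure them:
#
#   cell = NasNetANormalCell(num_conv_filters=44, drop_path_keep_prob=0.6,
#                            total_num_cells=18, total_training_steps=250000)
#   with arg_scope([get_channel_index, get_channel_dim, global_avg_pool,
#                   factorized_reduction], data_format=DATA_FORMAT_NHWC):
#     outputs = cell(inputs, scope='cell_0', filter_scaling=1, stride=1,
#                    prev_layer=prev, cell_num=0)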