# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A custom module for some common operations used by NASNet.

Functions exposed in this file:
- calc_reduction_layers
- get_channel_index
- get_channel_dim
- global_avg_pool
- factorized_reduction
- drop_path

Classes exposed in this file:
- NasNetABaseCell
- NasNetANormalCell
- NasNetAReductionCell
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf


arg_scope = tf.contrib.framework.arg_scope
slim = tf.contrib.slim

DATA_FORMAT_NCHW = 'NCHW'
DATA_FORMAT_NHWC = 'NHWC'
# Sentinel default for `data_format`; it must be overridden by the caller or
# by an enclosing `arg_scope`, otherwise the functions below assert.
INVALID = 'null'


def calc_reduction_layers(num_cells, num_reduction_layers):
  """Figure out what layers should have reductions.

  Reductions are spaced evenly: the i-th reduction (1-based) is placed at
  int(i / (num_reduction_layers + 1) * num_cells).

  Args:
    num_cells: Total number of cells in the network.
    num_reduction_layers: Number of reduction layers to place.

  Returns:
    A list of `num_reduction_layers` ints, the layer numbers at which a
    reduction should happen.
  """
  reduction_layers = []
  for pool_num in range(1, num_reduction_layers + 1):
    layer_num = (float(pool_num) / (num_reduction_layers + 1)) * num_cells
    layer_num = int(layer_num)
    reduction_layers.append(layer_num)
  return reduction_layers


@tf.contrib.framework.add_arg_scope
def get_channel_index(data_format=INVALID):
  """Returns the channel axis: 3 for 'NHWC', otherwise 1.

  Args:
    data_format: 'NHWC' or 'NCHW'. Must be supplied, either directly or via
      an enclosing `arg_scope`; the INVALID default trips the assertion.

  Returns:
    The integer index of the channel axis.
  """
  assert data_format != INVALID
  axis = 3 if data_format == 'NHWC' else 1
  return axis


@tf.contrib.framework.add_arg_scope
def get_channel_dim(shape, data_format=INVALID):
  """Returns the number of channels in a rank-4 `shape`.

  Args:
    shape: A rank-4 shape (e.g. `tensor.shape`).
    data_format: 'NHWC' or 'NCHW'. Must be supplied, either directly or via
      an enclosing `arg_scope`.

  Returns:
    The channel dimension as a Python int.

  Raises:
    ValueError: If `data_format` is neither 'NHWC' nor 'NCHW'.
  """
  assert data_format != INVALID
  assert len(shape) == 4
  if data_format == 'NHWC':
    return int(shape[3])
  elif data_format == 'NCHW':
    return int(shape[1])
  else:
    raise ValueError('Not a valid data_format', data_format)


@tf.contrib.framework.add_arg_scope
def global_avg_pool(x, data_format=INVALID):
  """Average pool away the height and width spatial dimensions of x.

  Args:
    x: A rank-4 tensor.
    data_format: 'NHWC' or 'NCHW'. Must be supplied, either directly or via
      an enclosing `arg_scope`.

  Returns:
    A rank-2 tensor of shape [batch, channels].
  """
  assert data_format != INVALID
  assert data_format in ['NHWC', 'NCHW']
  assert x.shape.ndims == 4
  if data_format == 'NHWC':
    return tf.reduce_mean(x, [1, 2])
  else:
    return tf.reduce_mean(x, [2, 3])


@tf.contrib.framework.add_arg_scope
def factorized_reduction(net, output_filters, stride, data_format=INVALID):
  """Reduces the shape of net without information loss due to striding.

  With `stride == 1` this is simply a 1x1 conv followed by batch norm.
  Otherwise two parallel paths each subsample `net` by `stride` and project
  it to `output_filters // 2` channels with a 1x1 conv. The second path is
  shifted by one pixel first, so (for stride 2) it samples exactly the
  positions the first path skips. The two paths are concatenated along the
  channel axis and batch-normalized.

  Args:
    net: A rank-4 input tensor.
    output_filters: Number of output channels; must be even.
    stride: Spatial stride of the reduction.
    data_format: 'NHWC' or 'NCHW'. Must be supplied, either directly or via
      an enclosing `arg_scope`.

  Returns:
    The reduced tensor with `output_filters` channels.
  """
  assert output_filters % 2 == 0, (
      'Need even number of filters when using this factorized reduction.')
  assert data_format != INVALID
  if stride == 1:
    net = slim.conv2d(net, output_filters, 1, scope='path_conv')
    net = slim.batch_norm(net, scope='path_bn')
    return net
  if data_format == 'NHWC':
    stride_spec = [1, stride, stride, 1]
  else:
    stride_spec = [1, 1, stride, stride]

  # Skip path 1: plain strided subsampling (1x1 avg pool == pick pixels).
  path1 = tf.nn.avg_pool(
      net, [1, 1, 1, 1], stride_spec, 'VALID', data_format=data_format)
  path1 = slim.conv2d(path1, output_filters // 2, 1, scope='path1_conv')

  # Skip path 2
  # First pad with 0's on the right and bottom, then shift the filter to
  # include those 0's that were added.
  if data_format == 'NHWC':
    pad_arr = [[0, 0], [0, 1], [0, 1], [0, 0]]
    path2 = tf.pad(net, pad_arr)[:, 1:, 1:, :]
    concat_axis = 3
  else:
    pad_arr = [[0, 0], [0, 0], [0, 1], [0, 1]]
    path2 = tf.pad(net, pad_arr)[:, :, 1:, 1:]
    concat_axis = 1

  path2 = tf.nn.avg_pool(
      path2, [1, 1, 1, 1], stride_spec, 'VALID', data_format=data_format)
  path2 = slim.conv2d(path2, output_filters // 2, 1, scope='path2_conv')

  # Concat and apply BN
  final_path = tf.concat(values=[path1, path2], axis=concat_axis)
  final_path = slim.batch_norm(final_path, scope='final_path_bn')
  return final_path


@tf.contrib.framework.add_arg_scope
def drop_path(net, keep_prob, is_training=True):
  """Drops out a whole example hiddenstate with the specified probability.

  During training each example in the batch is kept with probability
  `keep_prob` (and zeroed otherwise); kept examples are scaled by
  `1 / keep_prob`. When `is_training` is False, `net` is returned unchanged.

  Args:
    net: A rank-4 tensor; the noise is broadcast over all non-batch axes.
    keep_prob: Probability of keeping each example.
    is_training: Whether to apply drop_path.

  Returns:
    The (possibly) dropped-out tensor.
  """
  if is_training:
    batch_size = tf.shape(net)[0]
    noise_shape = [batch_size, 1, 1, 1]
    # floor(keep_prob + U[0, 1)) is 1 with probability keep_prob, else 0.
    random_tensor = keep_prob
    random_tensor += tf.random_uniform(noise_shape, dtype=tf.float32)
    binary_tensor = tf.floor(random_tensor)
    net = tf.div(net, keep_prob) * binary_tensor
  return net


def _operation_to_filter_shape(operation):
  """Parses the square kernel size out of an operation string.

  For example 'separable_3x3_2' -> 3 and 'separable_11x11_2' -> 11. The
  kernel token is the last '_'-separated token containing an 'x'.

  Args:
    operation: Operation string such as 'separable_5x5_2'.

  Returns:
    The kernel size as an int.

  Raises:
    AssertionError: If the kernel is not square.
  """
  kernel_token = [token for token in operation.split('_') if 'x' in token][-1]
  filter_height, filter_width = kernel_token.split('x')
  filter_shape = int(filter_height)
  assert filter_shape == int(filter_width), 'Rectangular filters not supported.'
  return filter_shape


def _operation_to_num_layers(operation):
  """Parses the number of stacked layers out of an operation string.

  For example 'separable_3x3_2' -> 2. When the last '_'-separated token is
  the kernel shape (e.g. 'separable_3x3'), a single layer is assumed.
  """
  splitted_operation = operation.split('_')
  if 'x' in splitted_operation[-1]:
    return 1
  return int(splitted_operation[-1])


def _operation_to_info(operation):
  """Takes in operation name and returns meta information.

  An example would be 'separable_3x3_4' -> (4, 3).

  Args:
    operation: String that corresponds to convolution operation.

  Returns:
    Tuple of (num layers, filter shape).
  """
  num_layers = _operation_to_num_layers(operation)
  filter_shape = _operation_to_filter_shape(operation)
  return num_layers, filter_shape


def _stacked_separable_conv(net, stride, operation, filter_size):
  """Applies a stack of relu -> separable conv -> batch norm layers.

  The number of layers and the square kernel size are parsed from
  `operation` (e.g. 'separable_5x5_2' -> 2 layers of 5x5). Only the first
  layer uses `stride`; every later layer uses stride 1.

  Args:
    net: Input tensor.
    stride: Stride of the first separable convolution.
    operation: Operation string such as 'separable_3x3_2'.
    filter_size: Number of output filters of every separable convolution.

  Returns:
    The output of the last batch norm.
  """
  num_layers, kernel_size = _operation_to_info(operation)
  for layer_num in range(num_layers - 1):
    net = tf.nn.relu(net)
    net = slim.separable_conv2d(
        net,
        filter_size,
        kernel_size,
        depth_multiplier=1,
        scope='separable_{0}x{0}_{1}'.format(kernel_size, layer_num + 1),
        stride=stride)
    net = slim.batch_norm(
        net, scope='bn_sep_{0}x{0}_{1}'.format(kernel_size, layer_num + 1))
    stride = 1
  net = tf.nn.relu(net)
  net = slim.separable_conv2d(
      net,
      filter_size,
      kernel_size,
      depth_multiplier=1,
      scope='separable_{0}x{0}_{1}'.format(kernel_size, num_layers),
      stride=stride)
  net = slim.batch_norm(
      net, scope='bn_sep_{0}x{0}_{1}'.format(kernel_size, num_layers))
  return net


def _operation_to_pooling_type(operation):
  """Takes in the operation string and returns the pooling type.

  For example 'avg_pool_3x3' -> 'avg'.
  """
  splitted_operation = operation.split('_')
  return splitted_operation[0]


def _operation_to_pooling_shape(operation):
  """Takes in the operation string and returns the pooling kernel shape.

  For example 'max_pool_3x3' -> 3. Only square kernels are supported.
  """
  splitted_operation = operation.split('_')
  shape = splitted_operation[-1]
  assert 'x' in shape
  filter_height, filter_width = shape.split('x')
  assert filter_height == filter_width
  return int(filter_height)


def _operation_to_pooling_info(operation):
  """Parses the pooling operation string to return its type and shape.

  For example 'avg_pool_3x3' -> ('avg', 3).
  """
  pooling_type = _operation_to_pooling_type(operation)
  pooling_shape = _operation_to_pooling_shape(operation)
  return pooling_type, pooling_shape


def _pooling(net, stride, operation):
  """Parses operation and performs the correct pooling operation on net.

  Args:
    net: Input tensor.
    stride: Pooling stride.
    operation: Pooling operation string such as 'avg_pool_3x3' or
      'max_pool_3x3'.

  Returns:
    The pooled tensor ('SAME' padding).

  Raises:
    NotImplementedError: If the pooling type is neither 'avg' nor 'max'.
  """
  padding = 'SAME'
  pooling_type, pooling_shape = _operation_to_pooling_info(operation)
  if pooling_type == 'avg':
    net = slim.avg_pool2d(net, pooling_shape, stride=stride, padding=padding)
  elif pooling_type == 'max':
    net = slim.max_pool2d(net, pooling_shape, stride=stride, padding=padding)
  else:
    raise NotImplementedError('Unimplemented pooling type: ', pooling_type)
  return net


class NasNetABaseCell(object):
  """NASNet Cell class that is used as a 'layer' in image architectures.

  Args:
    num_conv_filters: The number of filters for each convolution operation.
    operations: List of operations that are performed in the NASNet Cell in
      order. Consumed in pairs (left, right), one pair per combination step.
    used_hiddenstates: Binary array that signals if the hiddenstate was used
      within the cell. This is used to determine what outputs of the cell
      should be concatenated together: hiddenstates marked 0 (unused) are
      concatenated to form the cell output.
    hiddenstate_indices: Determines what hiddenstates should be combined
      together with the specified operations to create the NASNet cell.
      Indices 0 and 1 refer to the two cell inputs; later indices refer to
      the outputs of earlier combination steps.
    drop_path_keep_prob: Base keep probability for drop_path. Values < 1.0
      enable drop_path; the probability is further scaled by cell depth and
      by training progress.
    total_num_cells: Total number of cells in the network, used to scale the
      drop_path keep probability by cell position.
    total_training_steps: Number of global steps over which the drop_path
      keep probability is linearly annealed from 1.0 to its scaled value.
  """

  def __init__(self, num_conv_filters, operations, used_hiddenstates,
               hiddenstate_indices, drop_path_keep_prob, total_num_cells,
               total_training_steps):
    self._num_conv_filters = num_conv_filters
    self._operations = operations
    self._used_hiddenstates = used_hiddenstates
    self._hiddenstate_indices = hiddenstate_indices
    self._drop_path_keep_prob = drop_path_keep_prob
    self._total_num_cells = total_num_cells
    self._total_training_steps = total_training_steps

  def _reduce_prev_layer(self, prev_layer, curr_layer):
    """Matches dimension of prev_layer to the curr_layer.

    Args:
      prev_layer: Output of the earlier layer, or None.
      curr_layer: The current cell input.

    Returns:
      `curr_layer` if `prev_layer` is None; otherwise `prev_layer` adjusted
      to have `self._filter_size` channels and the same spatial size as
      `curr_layer` along axis 2.
    """
    # Set the prev layer to the current layer if it is none
    if prev_layer is None:
      return curr_layer
    curr_num_filters = self._filter_size
    prev_num_filters = get_channel_dim(prev_layer.shape)
    # Axis 2 is a spatial axis in both NHWC (width) and NCHW (height).
    curr_spatial_size = int(curr_layer.shape[2])
    prev_spatial_size = int(prev_layer.shape[2])
    if curr_spatial_size != prev_spatial_size:
      # NOTE(review): assumes prev_layer is exactly twice the spatial size of
      # curr_layer, since a stride-2 reduction is applied unconditionally.
      prev_layer = tf.nn.relu(prev_layer)
      prev_layer = factorized_reduction(
          prev_layer, curr_num_filters, stride=2)
    elif curr_num_filters != prev_num_filters:
      prev_layer = tf.nn.relu(prev_layer)
      prev_layer = slim.conv2d(
          prev_layer, curr_num_filters, 1, scope='prev_1x1')
      prev_layer = slim.batch_norm(prev_layer, scope='prev_bn')
    return prev_layer

  def _cell_base(self, net, prev_layer):
    """Runs the beginning of the conv cell before the predicted ops are run.

    Returns:
      A list of the two initial hiddenstates: [processed `net`,
      `prev_layer` matched to `net`].
    """
    num_filters = self._filter_size

    # Check to be sure prev layer stuff is setup correctly
    prev_layer = self._reduce_prev_layer(prev_layer, net)

    net = tf.nn.relu(net)
    net = slim.conv2d(net, num_filters, 1, scope='1x1')
    net = slim.batch_norm(net, scope='beginning_bn')
    split_axis = get_channel_index()
    # A single split simply wraps the tensor in a list of hiddenstates.
    net = tf.split(
        axis=split_axis, num_or_size_splits=1, value=net)
    for split in net:
      assert int(split.shape[split_axis]) == int(
          self._num_conv_filters * self._filter_scaling)
    net.append(prev_layer)
    return net

  def __call__(self, net, scope=None, filter_scaling=1, stride=1,
               prev_layer=None, cell_num=-1):
    """Runs the conv cell.

    Args:
      net: The cell input tensor.
      scope: Variable scope name for the cell.
      filter_scaling: Multiplier applied to `num_conv_filters` to obtain the
        number of filters used by this cell.
      stride: Stride used by operations applied to the two original cell
        inputs; operations on intermediate hiddenstates use stride 1.
      prev_layer: Output of an earlier layer, or None to reuse `net`.
      cell_num: Index of this cell; must be set (not -1) when drop_path is
        enabled.

    Returns:
      The channel-wise concatenation of the cell's unused hiddenstates.
    """
    self._cell_num = cell_num
    self._filter_scaling = filter_scaling
    self._filter_size = int(self._num_conv_filters * filter_scaling)

    i = 0
    with tf.variable_scope(scope):
      net = self._cell_base(net, prev_layer)
      for iteration in range(5):
        with tf.variable_scope('comb_iter_{}'.format(iteration)):
          left_hiddenstate_idx, right_hiddenstate_idx = (
              self._hiddenstate_indices[i],
              self._hiddenstate_indices[i + 1])
          # Indices 0 and 1 are the two original cell inputs.
          original_input_left = left_hiddenstate_idx < 2
          original_input_right = right_hiddenstate_idx < 2
          h1 = net[left_hiddenstate_idx]
          h2 = net[right_hiddenstate_idx]

          operation_left = self._operations[i]
          operation_right = self._operations[i+1]
          i += 2
          # Apply conv operations
          with tf.variable_scope('left'):
            h1 = self._apply_conv_operation(h1, operation_left,
                                            stride, original_input_left)
          with tf.variable_scope('right'):
            h2 = self._apply_conv_operation(h2, operation_right,
                                            stride, original_input_right)

          # Combine hidden states using 'add'.
          with tf.variable_scope('combine'):
            h = h1 + h2

          # Add hiddenstate to the list of hiddenstates we can choose from
          net.append(h)

      with tf.variable_scope('cell_output'):
        net = self._combine_unused_states(net)

      return net

  def _apply_conv_operation(self, net, operation,
                            stride, is_from_original_input):
    """Applies the predicted conv operation to net.

    Args:
      net: Input hiddenstate.
      operation: 'separable_*', 'none', or '*_pool_*' operation string.
      stride: Requested stride; reset to 1 unless the input is one of the
        two original cell inputs.
      is_from_original_input: Whether `net` is one of the original inputs.

    Returns:
      The transformed hiddenstate, with drop_path applied unless the
      operation is 'none'.

    Raises:
      ValueError: If the operation is not recognized.
    """
    # Don't stride if this is not one of the original hiddenstates
    if stride > 1 and not is_from_original_input:
      stride = 1
    input_filters = get_channel_dim(net.shape)
    filter_size = self._filter_size
    if 'separable' in operation:
      net = _stacked_separable_conv(net, stride, operation, filter_size)
    elif operation in ['none']:
      # Check if a stride is needed, then use a strided 1x1 here
      if stride > 1 or (input_filters != filter_size):
        net = tf.nn.relu(net)
        net = slim.conv2d(net, filter_size, 1, stride=stride, scope='1x1')
        net = slim.batch_norm(net, scope='bn_1')
    elif 'pool' in operation:
      net = _pooling(net, stride, operation)
      if input_filters != filter_size:
        net = slim.conv2d(net, filter_size, 1, stride=1, scope='1x1')
        net = slim.batch_norm(net, scope='bn_1')
    else:
      raise ValueError('Unimplemented operation', operation)

    if operation != 'none':
      net = self._apply_drop_path(net)
    return net

  def _combine_unused_states(self, net):
    """Concatenate the unused hidden states of the cell.

    Unused states whose spatial size or channel count differ from the last
    hiddenstate are first passed through `factorized_reduction` (stride 2 if
    the spatial size differs, else stride 1).
    """
    used_hiddenstates = self._used_hiddenstates

    # Axis 2 is a spatial axis in both NHWC and NCHW.
    final_height = int(net[-1].shape[2])
    final_num_filters = get_channel_dim(net[-1].shape)
    assert len(used_hiddenstates) == len(net)
    for idx, used_h in enumerate(used_hiddenstates):
      curr_height = int(net[idx].shape[2])
      curr_num_filters = get_channel_dim(net[idx].shape)

      # Determine if a reduction should be applied to make the number of
      # filters match.
      should_reduce = final_num_filters != curr_num_filters
      should_reduce = (final_height != curr_height) or should_reduce
      should_reduce = should_reduce and not used_h
      if should_reduce:
        stride = 2 if final_height != curr_height else 1
        with tf.variable_scope('reduction_{}'.format(idx)):
          net[idx] = factorized_reduction(
              net[idx], final_num_filters, stride)

    states_to_combine = (
        [h for h, is_used in zip(net, used_hiddenstates) if not is_used])

    # Return the concat of all the states
    concat_axis = get_channel_index()
    net = tf.concat(values=states_to_combine, axis=concat_axis)
    return net

  def _apply_drop_path(self, net):
    """Apply drop_path regularization to net.

    When the base keep probability is below 1.0, it is first scaled by cell
    depth and then linearly annealed from 1.0 over `total_training_steps`
    global steps before `drop_path` is applied.
    """
    drop_path_keep_prob = self._drop_path_keep_prob
    if drop_path_keep_prob < 1.0:
      # Scale keep prob by layer number
      assert self._cell_num != -1
      # Deeper cells (larger cell_num) get a lower keep probability.
      num_cells = self._total_num_cells
      layer_ratio = (self._cell_num + 1)/float(num_cells)
      with tf.device('/cpu:0'):
        tf.summary.scalar('layer_ratio', layer_ratio)
      drop_path_keep_prob = 1 - layer_ratio * (1 - drop_path_keep_prob)
      # Decrease the keep probability over time
      current_step = tf.cast(tf.train.get_or_create_global_step(),
                             tf.float32)
      drop_path_burn_in_steps = self._total_training_steps
      current_ratio = (
          current_step / drop_path_burn_in_steps)
      current_ratio = tf.minimum(1.0, current_ratio)
      with tf.device('/cpu:0'):
        tf.summary.scalar('current_ratio', current_ratio)
      drop_path_keep_prob = (
          1 - current_ratio * (1 - drop_path_keep_prob))
      with tf.device('/cpu:0'):
        tf.summary.scalar('drop_path_keep_prob', drop_path_keep_prob)
      net = drop_path(net, drop_path_keep_prob)
    return net


class NasNetANormalCell(NasNetABaseCell):
  """NASNetA Normal Cell."""

  def __init__(self, num_conv_filters, drop_path_keep_prob, total_num_cells,
               total_training_steps):
    operations = ['separable_5x5_2',
                  'separable_3x3_2',
                  'separable_5x5_2',
                  'separable_3x3_2',
                  'avg_pool_3x3',
                  'none',
                  'avg_pool_3x3',
                  'avg_pool_3x3',
                  'separable_3x3_2',
                  'none']
    used_hiddenstates = [1, 0, 0, 0, 0, 0, 0]
    hiddenstate_indices = [0, 1, 1, 1, 0, 1, 1, 1, 0, 0]
    super(NasNetANormalCell, self).__init__(num_conv_filters, operations,
                                            used_hiddenstates,
                                            hiddenstate_indices,
                                            drop_path_keep_prob,
                                            total_num_cells,
                                            total_training_steps)


class NasNetAReductionCell(NasNetABaseCell):
  """NASNetA Reduction Cell."""

  def __init__(self, num_conv_filters, drop_path_keep_prob, total_num_cells,
               total_training_steps):
    operations = ['separable_5x5_2',
                  'separable_7x7_2',
                  'max_pool_3x3',
                  'separable_7x7_2',
                  'avg_pool_3x3',
                  'separable_5x5_2',
                  'none',
                  'avg_pool_3x3',
                  'separable_3x3_2',
                  'max_pool_3x3']
    used_hiddenstates = [1, 1, 1, 0, 0, 0, 0]
    hiddenstate_indices = [0, 1, 0, 1, 0, 1, 3, 2, 2, 0]
    super(NasNetAReductionCell, self).__init__(num_conv_filters, operations,
                                               used_hiddenstates,
                                               hiddenstate_indices,
                                               drop_path_keep_prob,
                                               total_num_cells,
                                               total_training_steps)