import tensorflow as tf


def float32_variable_storage_getter(getter, name, shape=None, dtype=None,
                                    initializer=None, regularizer=None,
                                    trainable=True,
                                    *args, **kwargs):
    """Custom variable getter that forces trainable variables to be stored in
    float32 precision and then casts them to the training precision."""
    storage_dtype = tf.float32 if trainable else dtype
    variable = getter(name, shape, dtype=storage_dtype,
                      initializer=initializer,
                      regularizer=regularizer,
                      trainable=trainable,
                      *args, **kwargs)
    if trainable and dtype != tf.float32:
        # Reuse the cast op if this variable was already cast once (e.g. when
        # the enclosing variable scope is re-entered with reuse=True).
        cast_name = name + '/fp16_cast'
        try:
            cast_variable = tf.get_default_graph().get_tensor_by_name(
                cast_name + ':0')
        except KeyError:
            cast_variable = tf.cast(variable, dtype, name=cast_name)
        # Workaround relying on a private attribute: make the cast tensor
        # look like the underlying variable to downstream code.
        cast_variable._ref = variable._ref
        variable = cast_variable
    return variable


class LossScalingOptimizer(tf.train.Optimizer):
    """An optimizer wrapper that scales the loss before gradients are
    computed and un-scales the gradients afterwards, preventing small
    float16 gradients from underflowing to zero."""

    def __init__(self, optimizer,
                 scale=None,
                 name="LossScalingOptimizer",
                 use_locking=False):
        super(LossScalingOptimizer, self).__init__(
            name=name, use_locking=use_locking)
        self._optimizer = optimizer
        self._scale = float(scale) if scale is not None else 1.0

    def compute_gradients(self, loss, var_list=None, *args, **kwargs):
        if self._scale != 1.0:
            loss = tf.scalar_mul(self._scale, loss)
        gradvar = self._optimizer.compute_gradients(loss, var_list,
                                                    *args, **kwargs)
        if self._scale != 1.0:
            # Un-scale the gradients; variables that do not contribute to the
            # loss come back with a gradient of None and are passed through.
            gradvar = [(g if g is None else tf.scalar_mul(1. / self._scale, g),
                        v) for g, v in gradvar]
        return gradvar

    def apply_gradients(self, *args, **kwargs):
        return self._optimizer.apply_gradients(*args, **kwargs)
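

# Minimal usage sketch, not part of the utilities above; the toy model,
# shapes, and hyperparameters are illustrative assumptions. Trainable
# variables created under the custom getter are stored in float32 and cast
# to float16 for the forward/backward pass, while LossScalingOptimizer
# wraps a standard TF1 optimizer with a fixed loss scale.
if __name__ == '__main__':
    import numpy as np

    inputs = tf.placeholder(tf.float16, shape=[None, 32], name='inputs')
    labels = tf.placeholder(tf.float16, shape=[None, 1], name='labels')

    # Every get_variable call inside this scope (including those made by
    # tf.layers) goes through the fp32 storage getter.
    with tf.variable_scope('model',
                           custom_getter=float32_variable_storage_getter,
                           dtype=tf.float16):
        hidden = tf.layers.dense(inputs, 16, activation=tf.nn.relu)
        outputs = tf.layers.dense(hidden, 1)

    # Compute the loss in float32 to limit accumulation error.
    loss = tf.reduce_mean(tf.squared_difference(
        tf.cast(outputs, tf.float32), tf.cast(labels, tf.float32)))

    # A fixed scale of 128 is a common starting point for loss scaling.
    optimizer = LossScalingOptimizer(
        tf.train.MomentumOptimizer(learning_rate=0.01, momentum=0.9),
        scale=128.0)
    train_op = optimizer.minimize(loss)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        x = np.random.rand(8, 32).astype(np.float16)
        y = np.random.rand(8, 1).astype(np.float16)
        _, loss_val = sess.run([train_op, loss],
                               feed_dict={inputs: x, labels: y})
        print('loss:', loss_val)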