import tensorflow as tf


def float32_variable_storage_getter(getter, name, shape=None, dtype=None,
                                    initializer=None, regularizer=None,
                                    trainable=True,
                                    *args, **kwargs):
    """Custom variable getter that forces trainable variables to be stored in
    float32 precision and then casts them to the training precision."""
    storage_dtype = tf.float32 if trainable else dtype
    variable = getter(name, shape, dtype=storage_dtype,
                      initializer=initializer,
                      regularizer=regularizer,
                      trainable=trainable,
                      *args, **kwargs)
    if trainable and dtype != tf.float32:
        cast_name = name + '/fp16_cast'
        try:
            # Reuse an existing cast op for this variable if one is already
            # in the graph, instead of adding a duplicate cast.
            cast_variable = tf.get_default_graph().get_tensor_by_name(
                cast_name + ':0')
        except KeyError:
            cast_variable = tf.cast(variable, dtype, name=cast_name)
        # Forward the variable's ref so code that treats the returned cast
        # tensor like a variable can still reach the float32 master copy.
        cast_variable._ref = variable._ref
        variable = cast_variable
    return variable


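# Usage sketch (an assumption about intended use, not part of the original
# file): pass the getter as `custom_getter` so that variables requested in
# float16 inside the scope are stored as float32 masters and returned as
# float16 casts:
#
#     with tf.variable_scope('model',
#                            custom_getter=float32_variable_storage_getter):
#         w = tf.get_variable('w', shape=[64, 64], dtype=tf.float16)
#
# Here `w` is a float16 tensor, while the trainable variable 'model/w' that
# the optimizer updates is stored in float32.

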
class LossScalingOptimizer(tf.train.Optimizer):
    """An optimizer wrapper that scales the loss before gradients are
    computed and un-scales the resulting gradients, so that small float16
    gradient values do not underflow to zero."""

    def __init__(self, optimizer,
                 scale=None,
                 name="LossScalingOptimizer",
                 use_locking=False):
        super(LossScalingOptimizer, self).__init__(
            name=name, use_locking=use_locking)
        self._optimizer = optimizer
        # A scale of None is treated as 1.0, i.e. no scaling.
        self._scale = float(scale) if scale is not None else 1.0

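    # Why loss scaling helps (illustrative numbers, not from the original
    # file): float16 flushes values below ~6e-8 to zero and loses precision
    # below ~6e-5, so tiny gradients vanish in a float16 backward pass.
    # Multiplying the loss by e.g. 128 shifts every gradient up by 128x into
    # the representable range; dividing by 128 afterwards restores the true
    # magnitudes before the weight update.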
    def compute_gradients(self, loss, var_list=None, *args, **kwargs):
        if self._scale != 1.0:
            loss = tf.scalar_mul(self._scale, loss)
        gradvars = self._optimizer.compute_gradients(
            loss, var_list, *args, **kwargs)
        if self._scale != 1.0:
            # Un-scale gradients; skip None entries (vars with no gradient).
            gradvars = [(g if g is None else tf.scalar_mul(1. / self._scale, g), v)
                        for g, v in gradvars]
        return gradvars

    def apply_gradients(self, *args, **kwargs):
        return self._optimizer.apply_gradients(*args, **kwargs)
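

# Minimal end-to-end sketch (an assumption about intended use; the shapes,
# learning rate, and loss scale of 128 are illustrative, and
# GradientDescentOptimizer is just one possible optimizer to wrap): build a
# float16 model on float32 master weights, then train it with a scaled loss.
if __name__ == '__main__':
    with tf.variable_scope('demo',
                           custom_getter=float32_variable_storage_getter):
        x = tf.random_normal([8, 16], dtype=tf.float16)
        w = tf.get_variable('w', shape=[16, 1], dtype=tf.float16)
        loss = tf.reduce_mean(tf.square(tf.matmul(x, w)))
    # Apply the loss scale in float32 so the multiply cannot overflow fp16.
    loss = tf.cast(loss, tf.float32)
    optimizer = LossScalingOptimizer(
        tf.train.GradientDescentOptimizer(0.1), scale=128.0)
    train_op = optimizer.minimize(loss)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print(sess.run([train_op, loss]))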