import tensorflow as tf


class LazyUpdateOptimizer(tf.train.Optimizer):
    """Wraps another optimizer and accumulates gradients over `batch_size`
    calls before applying them, emulating a larger effective batch size."""

    def __init__(self, optimizer, batch_size=1, use_locking=True,
                 name="LazyUpdateOptimizer"):
        tf.train.Optimizer.__init__(self, use_locking=use_locking, name=name)
        self._name = name
        self._batch_size = batch_size
        self._grad_cache = {}
        self._optimizer = optimizer
        self._vars = []
        with tf.variable_scope(self._name):
            # Counts down from batch_size; the wrapped optimizer fires at 0.
            self._batch_count_variable = tf.get_variable(
                name="batch_count",
                shape=[],
                dtype=tf.int64,
                initializer=tf.constant_initializer(self._batch_size),
                collections=[tf.GraphKeys.LOCAL_VARIABLES])
            self._vars.append(self._batch_count_variable)

    @property
    def optimizer(self):
        return self._optimizer

    @property
    def name(self):
        return self._name

    @property
    def batch_size(self):
        return self._batch_size

    def get_initializer(self):
        return tf.group(*[v.initializer for v in self._vars])

    def apply_gradients(self, grads_and_vars, global_step=None, name=None):
        scope_name = self._name
        if name is not None:
            scope_name += "_" + name

        # Create one non-trainable accumulator per variable, colocated with
        # it, the first time the variable is seen.
        cached_grads = []
        for grad, var in grads_and_vars:
            if grad is None:
                continue
            if var is not None and var not in self._grad_cache:
                with tf.variable_scope(scope_name):
                    with tf.colocate_with(var):
                        cached_grad = tf.get_variable(
                            name=var.name.split(":")[0] + "_grad_cache",
                            dtype=var.dtype,
                            shape=var.shape,
                            initializer=tf.zeros_initializer(),
                            trainable=False,
                            collections=[tf.GraphKeys.LOCAL_VARIABLES])
                        self._vars.append(cached_grad)
                        self._grad_cache[var] = cached_grad
            cached_grads.append(self._grad_cache[var])

        with tf.name_scope(scope_name):
            # Always accumulate this step's gradients; apply the cached sums
            # only when the countdown reaches zero.
            cache_gradients_op = self.__cache_gradients(grads_and_vars, cached_grads)
            with tf.control_dependencies([cache_gradients_op]):
                apply_op = tf.cond(
                    tf.equal(self._batch_count_variable, 0),
                    true_fn=lambda: self.__actual_apply_gradients(
                        grads_and_vars, global_step=global_step),
                    false_fn=lambda: tf.no_op())
                with tf.control_dependencies([apply_op]):
                    return tf.no_op()

    def __cache_gradients(self, grads_and_vars, cached_grads):
        update_ops = []
        with tf.name_scope("cache_grad"):
            # cached_grads was built only for non-None gradients, so pair it
            # with the filtered list to keep the two sequences aligned.
            non_none_grads_and_vars = [(g, v) for g, v in grads_and_vars
                                       if g is not None]
            for (grad, var), cached_grad in zip(non_none_grads_and_vars,
                                                cached_grads):
                with tf.colocate_with(cached_grad):
                    if isinstance(grad, tf.Tensor):
                        update_op = tf.assign_add(cached_grad, grad)
                    elif isinstance(grad, tf.IndexedSlices):
                        update_op = tf.scatter_add(cached_grad,
                                                   grad.indices, grad.values)
                    else:
                        continue
                    update_ops.append(update_op)
            with tf.control_dependencies(
                    [tf.group(*update_ops, name="record_gradients")]):
                return tf.assign_sub(self._batch_count_variable, 1)

    def __actual_apply_gradients(self, grads_and_vars, global_step=None):
        # Hand the accumulated gradients to the wrapped optimizer.
        actual_grads_and_vars = [(self._grad_cache[var], var)
                                 for grad, var in grads_and_vars
                                 if grad is not None]
        apply_op = self._optimizer.apply_gradients(actual_grads_and_vars,
                                                   global_step=global_step)
        with tf.control_dependencies([apply_op]):
            # Reset the countdown and clear the accumulators for the next cycle.
            reset_ops = [tf.assign(self._batch_count_variable, self._batch_size)]
            for grad, var in actual_grads_and_vars:
                reset_ops.append(tf.assign(self._grad_cache[var],
                                           tf.zeros_like(var)))
            with tf.control_dependencies(reset_ops):
                return tf.no_op()
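A minimal usage sketch of the class above, assuming TF 1.x graph mode; the toy linear-regression model, the random data, and the hyperparameters (batch_size=8, learning rate, step count) are illustrative placeholders, not part of the original. Note that the cached gradients are summed rather than averaged, so the wrapped optimizer's learning rate may need to be scaled accordingly.

import numpy as np
import tensorflow as tf

# Toy regression graph, purely for illustration.
x = tf.placeholder(tf.float32, shape=[None, 4])
y = tf.placeholder(tf.float32, shape=[None, 1])
w = tf.get_variable("w", shape=[4, 1], initializer=tf.zeros_initializer())
loss = tf.reduce_mean(tf.square(tf.matmul(x, w) - y))

# Accumulate gradients for 8 calls, then let Adam apply the summed gradient.
optimizer = LazyUpdateOptimizer(tf.train.AdamOptimizer(1e-3), batch_size=8)
grads_and_vars = optimizer.compute_gradients(loss)  # inherited from tf.train.Optimizer
train_op = optimizer.apply_gradients(grads_and_vars)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(optimizer.get_initializer())  # initialize the local cache/counter variables
    for step in range(64):
        batch_x = np.random.randn(2, 4).astype(np.float32)
        batch_y = np.random.randn(2, 1).astype(np.float32)
        # The wrapped Adam update only runs once every 8 of these calls.
        sess.run(train_op, feed_dict={x: batch_x, y: batch_y})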