
TensorFlow Large Batch


When GPU memory is not enough to fit a large batch directly, we can simulate one by accumulating gradients over several smaller batches.

In TensorFlow it is fairly easy to implement this with a custom optimizer: wrap the actual optimizer and do just two things, gradient accumulation and delayed (lazy) updates.

import tensorflow as tf  # the code below targets the TF 1.x graph-mode API


class LazyUpdateOptimizer(tf.train.Optimizer):
 
    def __init__(self, optimizer, batch_size=1,
                 use_locking=True, name="LazyUpdateOptimizer"):
 
        tf.train.Optimizer.__init__(self, use_locking=use_locking, name=name)
 
        self._name = name
        self._batch_size = batch_size
        self._grad_cache = {}
        self._optimizer = optimizer
        self._vars = []
 
        with tf.variable_scope(self._name):
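            # Counts down from batch_size; the cached gradients are applied when it reaches 0.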
            self._batch_count_variable = \
                tf.get_variable(name="batch_count",
                                shape=[],
                                dtype=tf.int64,
                                initializer=tf.constant_initializer(self._batch_size),
                                collections=[tf.GraphKeys.LOCAL_VARIABLES])
            self._vars.append(self._batch_count_variable)
 
    @property
    def optimizer(self):
        return self._optimizer
 
    @property
    def name(self):
        return self._name
 
    @property
    def batch_size(self):
        return self._batch_size
 
    def get_initializer(self):
        return tf.group(*[v.initializer for v in self._vars])
 
    def apply_gradients(self, grads_and_vars, global_step=None, name=None):
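        """Accumulate the incoming gradients; only every `batch_size`-th call
        forwards the accumulated sum to the wrapped optimizer."""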
        scope_name = self._name
        if name is not None:
            scope_name += "_" + name
 
        # Keep only entries that actually have a gradient so that cached_grads
        # below stays aligned with grads_and_vars.
        grads_and_vars = [(grad, var) for grad, var in grads_and_vars
                          if grad is not None]

        cached_grads = []
        for _, var in grads_and_vars:
            # Lazily create one cache variable per trainable variable, colocated with it.
            if var not in self._grad_cache:
                with tf.variable_scope(scope_name):
                    with tf.colocate_with(var):
                        cached_grad = tf.get_variable(name=var.name.split(":")[0] + "_grad_cache",
                                                      dtype=var.dtype,
                                                      shape=var.shape,
                                                      initializer=tf.zeros_initializer(),
                                                      trainable=False,
                                                      collections=[tf.GraphKeys.LOCAL_VARIABLES])
                        self._vars.append(cached_grad)
                self._grad_cache[var] = cached_grad
            cached_grads.append(self._grad_cache[var])
 
        with tf.name_scope(scope_name):
            cache_gradients_op = self.__cache_gradients(grads_and_vars, cached_grads)
 
            with tf.control_dependencies([cache_gradients_op]):
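                # Only on every `batch_size`-th call (counter back at 0) do we actually
                # push the accumulated gradients into the wrapped optimizer.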
                apply_op = tf.cond(
                    tf.equal(self._batch_count_variable, 0),
                    true_fn=lambda: self.__actual_apply_gradients(grads_and_vars, global_step=global_step),
                    false_fn=lambda: tf.no_op())
                with tf.control_dependencies([apply_op]):
                    return tf.no_op()
 
    def __cache_gradients(self, grads_and_vars, cached_grads):
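        """Add the new gradients into the cache variables and decrement batch_count."""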
        update_ops = []
        with tf.name_scope("cache_grad"):
            for (grad, var), cached_grad in zip(grads_and_vars, cached_grads):
                with tf.colocate_with(cached_grad):
                    if isinstance(grad, tf.Tensor):
                        update_op = tf.assign_add(cached_grad, grad)
                    elif isinstance(grad, tf.IndexedSlices):
                        update_op = tf.scatter_add(cached_grad, grad.indices, grad.values)
                    else:
                        continue
 
                update_ops.append(update_op)
            with tf.control_dependencies([tf.group(*update_ops, name="record_gradients")]):
                return tf.assign_sub(self._batch_count_variable, 1)
 
    def __actual_apply_gradients(self, grads_and_vars, global_step=None):
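        """Apply the summed gradients with the wrapped optimizer, then reset
        the counter and zero the caches."""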
        actual_grads_and_vars = [(self._grad_cache[var], var) for grad, var in grads_and_vars if grad is not None]
 
        apply_op = self._optimizer.apply_gradients(actual_grads_and_vars, global_step=global_step)
        with tf.control_dependencies([apply_op]):
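            # After the real update: restart the countdown and zero every gradient cache.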
            reset_ops = [tf.assign(self._batch_count_variable, self._batch_size)]
 
            for grad, var in actual_grads_and_vars:
                reset_ops.append(tf.assign(self._grad_cache[var], tf.zeros_like(var)))
 
            with tf.control_dependencies(reset_ops):
                return tf.no_op()
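
A minimal usage sketch under TF 1.x (the toy regression model, the batch size of 4, and the training loop are illustrative assumptions, not part of the original post): wrap a plain optimizer, build the train op, and remember to run get_initializer() only after the graph is built, because the gradient caches are created lazily inside apply_gradients.

import numpy as np
import tensorflow as tf

# Toy regression problem, only here to exercise the wrapper.
x = tf.placeholder(tf.float32, [None, 8])
y = tf.placeholder(tf.float32, [None, 1])
w = tf.get_variable("w", shape=[8, 1])
loss = tf.reduce_mean(tf.square(tf.matmul(x, w) - y))

# Accumulate gradients over 4 calls before the wrapped optimizer updates w.
optimizer = LazyUpdateOptimizer(tf.train.GradientDescentOptimizer(0.1), batch_size=4)
train_op = optimizer.minimize(loss)          # minimize() is inherited and ends up in our apply_gradients()
init_cache_op = optimizer.get_initializer()  # build this after minimize(): the caches are created lazily

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(init_cache_op)
    for _ in range(16):
        xs = np.random.randn(32, 8).astype(np.float32)
        ys = xs.sum(axis=1, keepdims=True)
        # w only changes on every 4th run; the other runs just accumulate gradients.
        sess.run(train_op, feed_dict={x: xs, y: ys})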
