In TensorFlow, how can I (1) compute gradients and (2) update variables in *separate* @tf.function methods?

I need to compute tf.Variable gradients in a class method, but use those gradients to update the variables at a later time, in a different method. This works when I don't use the @tf.function decorator, but with @tf.function I get TypeError: An op outside of the function building code is being passed a "Graph" tensor. I've searched for an explanation of this error and how to resolve it, but have come up short.

Just FYI if you’re curious, I want to do this because I have variables that are in numerous different equations. Rather than trying to create a single equation that relates all the variables, it is easier (less computationally costly) to keep them separate, compute the gradients at a moment in time for each of those equations, and then incrementally apply the updates. I recognize that these two approaches are not mathematically identical.

Here is my code (a minimal example), followed by the results and error message. Note that when gradients are computed and used to update variables in a single method, .iterate(), there is no error.

import tensorflow as tf

class Example:
    def __init__(self, x, y, target, lr=0.01):
        self.x = x
        self.y = y
        self.target = target
        self.lr = lr
        self.variables = [self.x, self.y]

    @tf.function
    def iterate(self):
        with tf.GradientTape(persistent=False) as tape:
            loss = (self.target - self.x * self.y)**2

        self.gradients = tape.gradient(loss, self.variables)
        for g, v in zip(self.gradients, self.variables):
            v.assign_add(-self.lr * g)

    @tf.function
    def compute_update(self):
        with tf.GradientTape(persistent=False) as tape:
            loss = (self.target - self.x * self.y)**2

        self.gradients = tape.gradient(loss, self.variables)

    @tf.function
    def apply_update(self):
        for g, v in zip(self.gradients, self.variables):
            v.assign_add(-self.lr * g)


x = tf.Variable(1.)
y = tf.Variable(3.)
target = tf.Variable(5.)

example = Example(x, y, target)

# Compute and apply updates in a single tf.function method
example.iterate()
print('')
print(example.variables)
print('')

# Compute and apply updates in separate tf.function methods
example.compute_update()
example.apply_update()
print('')
print(example.variables)
print('')

The output:

$ python temp_bug.py 

[<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.12>, <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=3.04>]

Traceback (most recent call last):
  File "temp_bug.py", line 47, in <module>
    example.apply_update()
  File "/home/mroos/.local/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py", line 580, in __call__
    result = self._call(*args, **kwds)
  File "/home/mroos/.local/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py", line 650, in _call
    return self._concrete_stateful_fn._filtered_call(canon_args, canon_kwds)  # pylint: disable=protected-access
  File "/home/mroos/.local/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 1665, in _filtered_call
    self.captured_inputs)
  File "/home/mroos/.local/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 1746, in _call_flat
    ctx, args, cancellation_manager=cancellation_manager))
  File "/home/mroos/.local/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 598, in call
    ctx=ctx)
  File "/home/mroos/.local/lib/python3.6/site-packages/tensorflow/python/eager/execute.py", line 75, in quick_execute
    raise e
  File "/home/mroos/.local/lib/python3.6/site-packages/tensorflow/python/eager/execute.py", line 60, in quick_execute
    inputs, attrs, num_outputs)
TypeError: An op outside of the function building code is being passed
a "Graph" tensor. It is possible to have Graph tensors
leak out of the function building context by including a
tf.init_scope in your function building code.
For example, the following function will fail:
  @tf.function
  def has_init_scope():
    my_constant = tf.constant(1.)
    with tf.init_scope():
      added = my_constant * 2
The graph tensor has name: gradient_tape/mul/Mul:0
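Edit: one workaround I've found that avoids the error (a sketch, not necessarily the canonical fix) is to pre-allocate one tf.Variable per gradient in __init__, so that compute_update() assigns into those variables instead of stashing graph tensors on self between the two tf.function calls:

```python
import tensorflow as tf

class Example:
    def __init__(self, x, y, target, lr=0.01):
        self.x = x
        self.y = y
        self.target = target
        self.lr = lr
        self.variables = [self.x, self.y]
        # Pre-allocate a tf.Variable slot per gradient so no graph
        # tensor leaks out of the tf.function that computes it.
        self.gradients = [tf.Variable(tf.zeros_like(v)) for v in self.variables]

    @tf.function
    def compute_update(self):
        with tf.GradientTape() as tape:
            loss = (self.target - self.x * self.y)**2
        # Write the gradients into the pre-allocated variables.
        for g_var, g in zip(self.gradients, tape.gradient(loss, self.variables)):
            g_var.assign(g)

    @tf.function
    def apply_update(self):
        # Reads tf.Variables, not graph tensors, so this is safe.
        for g, v in zip(self.gradients, self.variables):
            v.assign_add(-self.lr * g)

x = tf.Variable(1.)
y = tf.Variable(3.)
example = Example(x, y, tf.Variable(5.))

example.compute_update()
example.apply_update()
print(example.variables)  # x ≈ 1.12, y ≈ 3.04, same as iterate()
```

Alternatively, compute_update() could simply return tape.gradient(loss, self.variables); the returned values are eager tensors outside the tf.function, and passing them back in as arguments to apply_update(grads) also avoids the leak.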

