
unable to create tf.variables inside a function that is decorated with @tf.function #51453

Open
sjtusmartboy opened this issue Aug 12, 2021 · 3 comments
Labels: comp:tf.function · stat:awaiting tensorflower · TF 2.5 · type:bug

Comments

sjtusmartboy commented Aug 12, 2021

tf 2.5

@tf.function
def weight_fn():
    # in TF 2.x, tf.truncated_normal is tf.random.truncated_normal; the shape is illustrative
    w = tf.Variable(tf.random.truncated_normal([10]))
    return w

I have a function like the one above that will be called about 50 times; each time it should create a new variable and return it. But according to the rule and the hint below,

    ValueError: A tf.Variable created inside your tf.function has been garbage-collected. Your code needs to keep Python references to variables created inside `tf.function`s.
    
    A common way to raise this error is to create and return a variable only referenced inside your function:
    
    @tf.function
    def f():
      v = tf.Variable(1.0)
      return v
    
    v = f()  # Crashes with this error message!
    
    The reason this crashes is that @tf.function annotated function returns a **`tf.Tensor`** with the **value** of the variable when the function is called rather than the variable instance itself. As such there is no code holding a reference to the `v` created inside the function and Python garbage collects it.
    
    The simplest way to fix this issue is to create variables outside the function and capture them:
    
    v = tf.Variable(1.0)
    
    @tf.function
    def f():
      return v
    
    f()  # <tf.Tensor: numpy=1.>
    v.assign_add(1.)
    f()  # <tf.Tensor: numpy=2.>

I should define the weight variables outside the tf.function, which means I would have to manually define over 50 weight variables, one per line:

w1 = tf.Variable(tf.random.truncated_normal([10]))
w2 = tf.Variable(tf.random.truncated_normal([10]))
w3 = tf.Variable(tf.random.truncated_normal([10]))
......
w50 = tf.Variable(tf.random.truncated_normal([10]))

Undoubtedly, this kind of behavior is really stupid. Are there any solutions to this kind of unreasonable rule?
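
For comparison, a minimal sketch of creating the 50 variables outside the tf.function without spelling out 50 assignments, assuming every weight shares the same shape (the shape [10] and the apply_weights consumer are only illustrative):

import tensorflow as tf

# Created eagerly, outside any tf.function, so Python keeps references to every Variable.
weights = [tf.Variable(tf.random.truncated_normal([10])) for _ in range(50)]

@tf.function
def apply_weights(x):
    # The variables are captured by the traced function rather than created inside it.
    return tf.add_n([w * x for w in weights])

apply_weights(tf.constant(2.0))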

@sanatmpa1 sanatmpa1 self-assigned this Aug 12, 2021
@sanatmpa1 sanatmpa1 added TF 2.5 Issues related to TF 2.5 comp:tf.function tf.function related issues labels Aug 12, 2021
@sanatmpa1 sanatmpa1 assigned ymodak and unassigned sanatmpa1 Aug 13, 2021
@sumanthratna

Not sure if you saw #49310 (comment), but ALLOW_DYNAMIC_VARIABLE_CREATION may help (untested):

import tensorflow as tf
from tensorflow.python.eager import def_function  # def_function.function is the same as tf.function
from tensorflow.python.ops import variables


def_function.ALLOW_DYNAMIC_VARIABLE_CREATION = True

vars = {}  # keeps Python references to the created variables so they are not garbage-collected

@def_function.function
def weight_fn(val, key):
    if key not in vars:
        vars[key] = variables.Variable(val)
    return vars[key]

# the shape [10] is illustrative; in TF 2.x tf.truncated_normal is tf.random.truncated_normal
weights = [weight_fn(tf.random.truncated_normal([10]), f"w{ind+1}") for ind in range(50)]
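
Assuming the flag behaves as discussed in #49310, the module-level vars dict is what keeps Python references to the created variables so they are not garbage-collected, while ALLOW_DYNAMIC_VARIABLE_CREATION lifts tf.function's usual restriction on creating variables after the first trace. Note that the values returned from the tf.function are tensors, so the Variable objects themselves should be read from vars.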

@sjtusmartboy (Author)

@sumanthratna Thanks for the reminder, you are very kind. It seems this is the best solution we can get for now.

bhack (Contributor) commented Aug 24, 2021

/cc @mdanatg

@ymodak ymodak added the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Aug 25, 2021