
Gradient of matmul in while_loop works when run eagerly but not as tf.function #35012

Closed
mjwatkins2 opened this issue Dec 11, 2019 · 15 comments

Comments

@mjwatkins2

System information

  • Have I written custom code: Example provided below
  • OS Platform and Distribution: Both Windows 10 and Google Colab
  • TensorFlow installed from binary
  • TensorFlow version: Both 2.0.0 and 2.1.0-rc0
  • Python version: 3.6 and 3.7.4

Describe the current behavior
If I multiply tensors together in a while_loop, I can compute higher-order gradients when running eagerly, but not when running inside a function with the @tf.function decorator. If I manually unroll the loop it works either way, so the problem must be related to the interaction between tf.function and tf.while_loop.

TypeError: in converted code:

<ipython-input-17-50c1c9a72b0f>:17 func  *
    d3y_dx3 = t.gradient(d2y_dx2, x)
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/backprop.py:999 gradient
    if not backprop_util.IsTrainable(t):
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/backprop_util.py:30 IsTrainable
    dtype = dtypes.as_dtype(dtype)
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/dtypes.py:725 as_dtype
    (type_value,))

TypeError: Cannot convert value None to a TensorFlow DType.

Describe the expected behavior
The output of this function should be the same regardless of the @tf.function decorator:

[[0 0 4]]
[[[-2 1 10]]]
[[[8 2 20]]]
[[[-18 6 30]]]
[[[24 24 24]]]
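
(For context: the loop builds y = x + x^2 + x^3 + x^4 elementwise, so these lines are y followed by its first through fourth derivatives, evaluated at x = -1, 0 and 1.)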

Code to reproduce the issue
Run without and then with @tf.function:

import tensorflow as tf

x = tf.Variable([[[-1.0]], [[0.0]], [[1.0]]])

# @tf.function
def func():
    with tf.GradientTape(persistent=True) as t:
        t.watch(x)
        pv = x
        
        n = 4
        c = lambda i, _: i < n
        b = lambda i, pv: (i+1, tf.concat([x, tf.matmul(pv, x)], axis=1))
        _, pv = tf.while_loop(c, b, (tf.constant(1), pv), shape_invariants=(tf.constant(1).shape, tf.TensorShape([3, None, 1])))

        # manual loop unrolling works fine both ways:
        # pv = tf.concat([x, tf.matmul(pv, x)], axis=1)
        # pv = tf.concat([x, tf.matmul(pv, x)], axis=1)
        # pv = tf.concat([x, tf.matmul(pv, x)], axis=1)
        
        y = tf.reduce_sum(pv, axis=1)
        dy_dx = t.gradient(y, x)
        d2y_dx2 = t.gradient(dy_dx, x)
        d3y_dx3 = t.gradient(d2y_dx2, x)
        d4y_dx4 = t.gradient(d3y_dx3, x)
    del t

    tf.print(tf.transpose(y)) # transpose to print on one line
    tf.print(tf.transpose(dy_dx))
    tf.print(tf.transpose(d2y_dx2))
    tf.print(tf.transpose(d3y_dx3))
    tf.print(tf.transpose(d4y_dx4))

func()
@oanush oanush self-assigned this Dec 12, 2019
@oanush oanush added comp:autograph Autograph related issues TF 2.1 for tracking issues in 2.1 release labels Dec 12, 2019
@goldiegadde goldiegadde added the TF 2.0 Issues relating to TensorFlow 2.0 label Dec 12, 2019
@oanush commented Dec 12, 2019

The issue replicates on TF 2.1; kindly find the Colab gist. Thanks!

@oanush oanush added type:bug Bug and removed TF 2.0 Issues relating to TensorFlow 2.0 labels Dec 12, 2019
@oanush oanush assigned gowthamkpr and unassigned oanush Dec 12, 2019
@gowthamkpr gowthamkpr assigned mdanatg and unassigned gowthamkpr Dec 12, 2019
@gowthamkpr gowthamkpr added the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Dec 12, 2019
@mdanatg mdanatg assigned alextp and unassigned mdanatg Dec 13, 2019
@mdanatg mdanatg added comp:core issues related to core part of tensorflow comp:ops OPs related issues and removed comp:autograph Autograph related issues labels Dec 13, 2019
@alextp (Contributor) commented Dec 13, 2019

@allenlavoie does this ring any bells?

@allenlavoie (Member)

We have not yet fixed higher-order gradients for cond/while under a tape AFAIK. I believe they need a similar fix to the one we applied to tf.function for recording/accepting gradients of side-outputs. CC @saxenasaurabh. tf.gradients likely works.

That said, I'm not sure why this specific error is the one we see first. If we're not going to fix/test these soon we should probably start throwing an exception (we've had an internal bug open for a while).
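
For reference, here is a minimal sketch of that tf.gradients suggestion, adapted from the repro above. It is not a confirmed fix and the function name is made up; tf.gradients only works in a graph context such as tf.function, which is why it is called inside the decorated function:

import tensorflow as tf

x = tf.Variable([[[-1.0]], [[0.0]], [[1.0]]])

@tf.function
def func_with_tf_gradients():
    pv = x
    n = 4
    c = lambda i, _: i < n
    b = lambda i, pv: (i + 1, tf.concat([x, tf.matmul(pv, x)], axis=1))
    _, pv = tf.while_loop(c, b, (tf.constant(1), pv),
                          shape_invariants=(tf.constant(1).shape, tf.TensorShape([3, None, 1])))
    y = tf.reduce_sum(pv, axis=1)
    # tf.gradients is graph-only, so it replaces the tape here; whether it
    # sidesteps the None-dtype error is only a guess based on the comment above.
    dy_dx = tf.gradients(y, x)[0]
    d2y_dx2 = tf.gradients(dy_dx, x)[0]
    return y, dy_dx, d2y_dx2

y, dy_dx, d2y_dx2 = func_with_tf_gradients()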

@tensorflowbutler tensorflowbutler removed the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Dec 14, 2019
@Athul8raj

Since the gradient tape with tf.while_loop was not running under the tf.function decorator, I moved the loop (a plain Python while) into another function and decorated that with tf.function. It seems to be working.

Please find the gist here :
https://colab.research.google.com/drive/1Y_0EfYzkA6dWjCCL02mz3vOgJRdwpqd1

@saxenasaurabh saxenasaurabh self-assigned this Dec 16, 2019
@saxenasaurabh (Member)

Wrapping in tf.function sounds reasonable till we have a proper fix. Runtime performance would be similar since we would inline the inner tf.function anyway.
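
For reference, a minimal sketch of what that wrapped-loop workaround might look like, reconstructed from the comments rather than from the linked gist (the helper name build_series and the trace-time Python loop are assumptions):

import tensorflow as tf

x = tf.Variable([[[-1.0]], [[0.0]], [[1.0]]])

# The loop lives in its own tf.function; because the counter and bound are
# plain Python ints, the loop is unrolled while tracing, which the original
# report says works fine.
@tf.function
def build_series(pv):
    i = 1
    while i < 4:
        pv = tf.concat([x, tf.matmul(pv, x)], axis=1)
        i += 1
    return pv

def func():
    with tf.GradientTape(persistent=True) as t:
        t.watch(x)
        y = tf.reduce_sum(build_series(x), axis=1)
        dy_dx = t.gradient(y, x)
        d2y_dx2 = t.gradient(dy_dx, x)
        d3y_dx3 = t.gradient(d2y_dx2, x)
        d4y_dx4 = t.gradient(d3y_dx3, x)
    del t
    tf.print(tf.transpose(y))
    tf.print(tf.transpose(dy_dx))
    tf.print(tf.transpose(d2y_dx2))
    tf.print(tf.transpose(d3y_dx3))
    tf.print(tf.transpose(d4y_dx4))

func()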

@mjwatkins2 (Author)

Fair enough, thank you all.

@ravikyram (Contributor)

The issue replicates with TF 2.2-rc3. Please find the gist here. Thanks!

@Saduf2019 (Contributor)

The issue replicates with TF 2.4.0-dev20200817. Please find the gist here. Thanks!

@sushreebarsa (Contributor)

Was able to replicate the issue in TF v2.5; please find the gist here. Thanks!

@sushreebarsa (Contributor)

Was able to replicate the issue with TF 2.6.0-dev20210606; please find the gist here. Thanks!

@kumariko

I could reproduce the issue with TF 2.6. Please find the gist here. Thanks!

@saxenasaurabh saxenasaurabh removed their assignment Dec 13, 2021
@Saduf2019 Saduf2019 added the TF 2.7 Issues related to TF 2.7.0 label Dec 20, 2021
@chunduriv (Contributor) commented Jul 26, 2022

I was able to replicate the issue in tf-nightly 2.12.0-dev20221215. Please find the gist for reference. Thank you.

@chunduriv chunduriv added TF 2.9 Issues found in the TF 2.9 release (or RCs) stat:awaiting tensorflower Status - Awaiting response from tensorflower and removed TF 2.1 for tracking issues in 2.1 release TF 2.7 Issues related to TF 2.7.0 labels Jul 26, 2022
@tilakrayal (Contributor)

@mjwatkins2,
I tried to execute the code from the alternative workaround and it ran without any issue/error on tensorflow nightly (2.13.0-dev20230305). Kindly find the gist of it here. Thank you!

@tilakrayal tilakrayal added the stat:awaiting response Status - Awaiting response from author label Mar 6, 2023
@tilakrayal tilakrayal self-assigned this Mar 15, 2023
@tilakrayal tilakrayal added the stale This label marks the issue/pr stale - to be closed automatically if no activity label Mar 15, 2023
@tilakrayal (Contributor)

Closing this as stale. Please reopen if this is still a valid request. Thank you!

