Gradient of matmul in while_loop works when run eagerly but not as tf.function #35012
Comments
Issue replicates in TF 2.1; kindly find the Colab gist. Thanks!
@allenlavoie does this ring any bells?
We have not yet fixed higher-order gradients for cond/while under a tape, AFAIK. I believe they need a fix similar to the one we applied to tf.function for recording and accepting gradients of side outputs. CC @saxenasaurabh. That said, I'm not sure why this specific error is the first one we see. If we're not going to fix and test these soon, we should probably start throwing an exception (we've had an internal bug open for a while).
Since the gradient tape with tf.while_loop was not working under the tf.function decorator, I moved the while loop (a Python while) into another function and decorated that function with tf.function. It seems to work. Please find the gist here:
Wrapping in
Fair enough, thank you all.
The issue replicates with TF 2.2-rc3. Please find the gist here. Thanks!
The issue replicates with TF 2.4.0-dev20200817. Please find the gist here. Thanks!
I was able to replicate the issue in TF v2.5; please find the gist here. Thanks!
I was able to replicate the issue with TF 2.6.0-dev20210606; please find the gist here. Thanks!
I could reproduce the issue with TF 2.6. Please find the
I was able to replicate the issue in tf-nightly 2.12.0-dev20221215. Please find the gist for reference. Thank you.
@mjwatkins2,
Closing this as stale. Please reopen if this is still a valid request. Thank you!
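The linked gist is not included on this page. The following is a hedged sketch of the workaround as described, not the commenter's actual code. It assumes a 1x1 matrix input, a loop count of 3, and hypothetical helper names (`power_by_matmul`, `second_derivative`). The key point is that the nested tapes stay outside the decorated function and only the loop is traced.

```python
import tensorflow as tf


@tf.function
def power_by_matmul(x, n=3):
    # Hypothetical helper: multiplies a (1, 1) matrix by itself n times.
    # `i` and `n` are Python ints, so AutoGraph keeps this as an ordinary
    # Python `while` that is unrolled at trace time (no tf.while_loop op).
    y = tf.ones_like(x)
    i = 0
    while i < n:
        y = tf.matmul(y, x)
        i += 1
    return y


def second_derivative(x):
    # The tapes run eagerly, outside the tf.function; only the loop is traced.
    with tf.GradientTape() as outer:
        outer.watch(x)
        with tf.GradientTape() as inner:
            inner.watch(x)
            y = power_by_matmul(x)
        # Computed inside `outer` so this gradient is itself recorded.
        dy_dx = inner.gradient(y, x)
    d2y_dx2 = outer.gradient(dy_dx, x)
    return dy_dx, d2y_dx2


x = tf.constant([[2.0]])
dy_dx, d2y_dx2 = second_derivative(x)
print(dy_dx.numpy(), d2y_dx2.numpy())  # d/dx x**3 = 3x**2, d2/dx2 = 6x
```

Because the loop bound is a Python int, this sidesteps tf.while_loop entirely. That may be why it behaves differently from the failing case, rather than indicating that the underlying bug is fixed.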
System information
Describe the current behavior
If I multiply tensors together in a tf.while_loop, I can compute higher-order gradients when running eagerly, but not inside a function decorated with @tf.function. If I manually unroll the loop, it works either way, so the problem must be related to the interaction between tf.function and tf.while_loop.
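To illustrate the "manually unrolled" case, here is a minimal sketch. It assumes a 1x1 matrix input, a loop count of 3, and hypothetical helpers `unrolled_power` and `second_derivative`. The same nested-tape computation runs both eagerly and through tf.function. Because the loop is a Python `for` over `range`, it is unrolled during tracing and no loop op appears in the graph.

```python
import tensorflow as tf


def unrolled_power(x, n=3):
    # Hypothetical helper: x**n for a (1, 1) matrix via n explicit matmuls.
    # `range(n)` over a Python int is unrolled, so no loop op is created.
    y = tf.ones_like(x)
    for _ in range(n):
        y = tf.matmul(y, x)
    return y


def second_derivative(x):
    # Both tapes live inside the (possibly traced) function, as in the report.
    with tf.GradientTape() as outer:
        outer.watch(x)
        with tf.GradientTape() as inner:
            inner.watch(x)
            y = unrolled_power(x)
        dy_dx = inner.gradient(y, x)
    return dy_dx, outer.gradient(dy_dx, x)


x = tf.constant([[2.0]])
eager = second_derivative(x)
traced = tf.function(second_derivative)(x)
for name, (dy, d2y) in (("eager", eager), ("tf.function", traced)):
    print(name, dy.numpy(), d2y.numpy())
```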
Describe the expected behavior
The output of this function should be the same regardless of the @tf.function decorator:
Code to reproduce the issue
Run without and then with @tf.function:
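The original reproduction script is not preserved on this page. The following is a hedged sketch of the same pattern, not the reporter's code. It assumes a 1x1 matrix input, a loop count of 3, and hypothetical helpers `loop_power`, `second_derivative`, and `run`. It runs the nested-tape computation over tf.while_loop first eagerly, then through tf.function, and reports any exception instead of crashing. No output is claimed for the tf.function run, because that behavior is exactly what the issue is about and may differ between TF versions.

```python
import tensorflow as tf


def loop_power(x, n=3):
    # Hypothetical helper: x**n for a (1, 1) matrix using tf.while_loop.
    def cond(i, y):
        return i < n

    def body(i, y):
        return (i + 1, tf.matmul(y, x))

    _, y = tf.while_loop(cond, body, (tf.constant(0), tf.ones_like(x)))
    return y


def second_derivative(x):
    # Nested tapes around the loop, the pattern the report describes.
    with tf.GradientTape() as outer:
        outer.watch(x)
        with tf.GradientTape() as inner:
            inner.watch(x)
            y = loop_power(x)
        dy_dx = inner.gradient(y, x)
    return dy_dx, outer.gradient(dy_dx, x)


def run(fn, x):
    # Report ("ok", values) or ("error", message) rather than raising.
    try:
        dy, d2y = fn(x)
    except Exception as e:
        return "error", f"{type(e).__name__}: {e}"

    def to_np(t):
        # A disconnected gradient comes back as None.
        return None if t is None else t.numpy()

    return "ok", (to_np(dy), to_np(d2y))


x = tf.constant([[2.0]])
print("eager:      ", run(second_derivative, x))
print("tf.function:", run(tf.function(second_derivative), x))
```

Eagerly, both derivatives of x**3 at x = 2 should be 12. Any difference in the tf.function line is the behavior under discussion.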