tf.dynamic_stitch gradient is incorrect #7397

Open
drasmuss opened this issue Feb 9, 2017 · 15 comments
Labels
comp:ops, stat:awaiting tensorflower, TF 2.9, type:bug

Comments

@drasmuss
Contributor
drasmuss commented Feb 9, 2017

If possible, provide a minimal reproducible example (We usually don't have time to read hundreds of lines of your code)

Original reproduction code (TensorFlow 1.0)

import tensorflow as tf

x = tf.zeros((1, 3))
y = tf.dynamic_stitch([[0], [0]], [x, tf.ones((1, 3))])

with tf.Session() as sess:
    print("y")
    print(sess.run(y))

    analytic, numeric = tf.test.compute_gradient(x, (1, 3), y, (1, 3))
    print("analytic")
    print(analytic)
    print("numeric")
    print(numeric)

Updated reproduction code (TensorFlow 2.16)

import tensorflow as tf

x = tf.zeros((1, 3))

analytic, numeric = tf.test.compute_gradient(
    lambda x: tf.dynamic_stitch([[0], [0]], [x, tf.ones((1, 3))]), [x]
)
print("analytic")
print(analytic)
print("numeric")
print(numeric)

The original TensorFlow 1.0 code gives the output

y
[[ 1.  1.  1.]]
analytic
[[ 1.  0.  0.]
 [ 0.  1.  0.]
 [ 0.  0.  1.]]
numeric
[[ 0.  0.  0.]
 [ 0.  0.  0.]
 [ 0.  0.  0.]]

The numeric gradient correctly shows that x has no impact on y (since the value of x is completely overwritten by a constant in the dynamic_stitch). The analytic gradient is incorrect; it seems like the gradient calculation in dynamic_stitch does not handle the case where there are duplicate indices being merged.
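
To make the reasoning above concrete, here is a minimal pure-Python sketch; it is not TensorFlow's implementation. It assumes the documented "later writes win" merge rule, and `stitch`, `numeric_jacobian`, and `y_of_x` are hypothetical names introduced only for illustration. Under those assumptions, a central-difference Jacobian of the stitched output with respect to x comes out as all zeros, which matches the numeric result reported here.

```python
def stitch(indices, data):
    """Pure-Python model of the merge rule: a later write to a row wins.

    indices: list of lists of int row indices
    data:    list of lists of rows (each row is a list of floats)
    """
    n_rows = max(i for idx in indices for i in idx) + 1
    merged = [None] * n_rows  # rows that nobody writes stay None in this sketch
    for idx, rows in zip(indices, data):
        for i, row in zip(idx, rows):
            merged[i] = list(row)  # overwrites any earlier row at index i
    return merged


def numeric_jacobian(f, x, eps=1e-3):
    """Central-difference Jacobian of f at x; both are flat lists of floats."""
    n_out = len(f(x))
    jac = [[0.0] * n_out for _ in x]  # one row per input element
    for k in range(len(x)):
        hi = list(x)
        hi[k] += eps
        lo = list(x)
        lo[k] -= eps
        f_hi, f_lo = f(hi), f(lo)
        for j in range(n_out):
            jac[k][j] = (f_hi[j] - f_lo[j]) / (2 * eps)
    return jac


def y_of_x(x_flat):
    # Same setup as the issue: x and a row of ones both target index 0.
    merged = stitch([[0], [0]], [[x_flat], [[1.0, 1.0, 1.0]]])
    return merged[0]


jacobian = numeric_jacobian(y_of_x, [0.0, 0.0, 0.0])
print(jacobian)  # [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
```

Because the row of ones is written after x, perturbing x never changes the output in this model, so every finite difference is exactly zero.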

@girving
Contributor
girving commented Feb 10, 2017

Ug. You're correct that the gradients are wrong, but I don't see how to fix it without a dramatic performance hit. Do you have any suggestions?

@girving girving added the stat:awaiting response and type:bug labels Feb 10, 2017
drasmuss added 7 commits to drasmuss/tensorflow that referenced this issue between Feb 13 and Feb 16, 2017 (two on Feb 13, three on Feb 14, one each on Feb 15 and Feb 16), each with the message:
Addresses tensorflow#7397

Expanded unit tests to cover these cases.
@aselle
Contributor
aselle commented Mar 3, 2017

Automatically closing due to lack of recent activity. Please update the issue when new information becomes available, and we will reopen the issue. Thanks!

@aselle aselle closed this as completed Mar 3, 2017
@drasmuss
Contributor Author
drasmuss commented Mar 4, 2017

The bug still exists, meaning that the tf.dynamic_stitch gradients are incorrect. Is there any other information I can provide that would be helpful?

@girving girving reopened this Mar 6, 2017
@girving
Contributor
girving commented Mar 6, 2017

Let's leave this open. Anyone interested should refer to the comments in #7487. The next step would have been to add a new C++ kernel to speed up the bookkeeping required by accurate gradients.
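
For readers wondering what that bookkeeping might look like, here is a rough pure-Python sketch. It is not the kernel discussed in #7487 (which is not shown in this thread), and `stitch_grads` is a hypothetical name. The assumption is that an accurate gradient only flows to the slice that wrote each merged row last; every overwritten slice gets zeros.

```python
def stitch_grads(indices, grad_rows):
    """Route the upstream gradient only to the slice that wrote each row last.

    indices:   list of lists of int row indices (same layout as the forward op)
    grad_rows: upstream gradient, one row (list of floats) per merged row
    Returns one gradient per data input, shaped like that input.
    """
    # Bookkeeping pass: for each merged row, remember the (m, i) that wins.
    last_writer = {}
    for m, idx in enumerate(indices):
        for i, row_index in enumerate(idx):
            last_writer[row_index] = (m, i)

    grads = []
    for m, idx in enumerate(indices):
        grads_m = []
        for i, row_index in enumerate(idx):
            if last_writer[row_index] == (m, i):
                grads_m.append(list(grad_rows[row_index]))
            else:
                # Overwritten in the forward pass, so no gradient flows back.
                grads_m.append([0.0] * len(grad_rows[row_index]))
        grads.append(grads_m)
    return grads


# The example from this issue: x and a row of ones both write to row 0.
grad_x, grad_ones = stitch_grads([[0], [0]], [[1.0, 1.0, 1.0]])
print(grad_x)     # [[0.0, 0.0, 0.0]]
print(grad_ones)  # [[1.0, 1.0, 1.0]]
```

The extra pass over all indices is the cost being weighed here; a real implementation would presumably do this inside a kernel rather than in Python loops.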

@girving girving added the stat:contribution welcome label and removed the stat:awaiting response label Mar 6, 2017
@bhack
Contributor
bhack commented May 13, 2021

Can we close this?

@drasmuss
Contributor Author

The gradient implementation is still incorrect as of TF 2.5.0rc. Here is an updated example showing the same error:

import tensorflow as tf

x = tf.zeros((1, 3))

analytic, numeric = tf.test.compute_gradient(
    lambda x: tf.dynamic_stitch([[0], [0]], [x, tf.ones((1, 3))]), [x]
)
print("analytic")
print(analytic)
print("numeric")
print(numeric)

gives

analytic
(array([[1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.]], dtype=float32),)
numeric
(array([[0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.]], dtype=float32),)

@bhack
Contributor
bhack commented May 13, 2021

/cc @rthadur Can we update the label?

@rthadur rthadur removed the stat:contribution welcome label May 13, 2021
@ymodak ymodak added the comp:ops, TF 2.5, and stat:awaiting tensorflower labels May 19, 2021
@chunduriv
Contributor

I was able to replicate the issue in tf-nightly 2.10.0-dev20220714. Please find the gist for reference. Thank you.

@chunduriv chunduriv added the TF 2.9 label and removed the TF 2.5 label Jul 15, 2022
@chunduriv chunduriv assigned chunduriv and unassigned ymodak Jul 15, 2022
@chunduriv chunduriv assigned mohantym and unassigned chunduriv Sep 29, 2022
@mohantym
Contributor
mohantym commented Sep 30, 2022

Hi @drasmuss!
Just sharing my observations based on yesterday's trials.
@drasmuss has used the same indices, [0] and [0], to stitch two tensors ([0,0,0] and [1,1,1]). So while trying to place a tensor at the same index, it is taking the max of the two tensors and the dynamic stitch fails.

If we use two different indices, like [0], [1] or [0], [2], then the analytic and numeric results from tf.test.compute_gradient are the same.
@yongtang @bhack
Maybe we can put an assertion in the code itself to check whether the user is passing different indices.
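
As a rough idea of what such a check could look like on the caller's side, here is a small pure-Python sketch. `find_duplicate_indices` and `check_unique_indices` are hypothetical helpers, not TensorFlow functions, and a guard like this only makes sense for callers who never intend to pass duplicate indices.

```python
from collections import Counter


def find_duplicate_indices(indices):
    """Return the sorted row indices that appear more than once across all lists."""
    counts = Counter(i for idx in indices for i in idx)
    return sorted(i for i, n in counts.items() if n > 1)


def check_unique_indices(indices):
    """Raise ValueError if any row index is targeted more than once."""
    duplicates = find_duplicate_indices(indices)
    if duplicates:
        raise ValueError(f"duplicate indices: {duplicates}")


print(find_duplicate_indices([[0], [0]]))  # [0]
print(find_duplicate_indices([[0], [1]]))  # []
```

Calling `check_unique_indices([[0], [0]])` before the stitch would raise for the example in this issue, while `[[0], [1]]` would pass.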

Attached gist for reference.

Thank you!

@drasmuss
Contributor Author
drasmuss commented Sep 30, 2022

Yes, this bug is caused by having duplicate indices. But that is defined and supported behaviour for dynamic stitch (e.g., see the documentation):

Values are merged in order, so if an index appears in both indices[m][i] and indices[n][j] for (m,i) < (n,j) the slice data[n][j] will appear in the merged result

@mohantym
Contributor

Ok @drasmuss !
Thanks for the update.

@mohantym mohantym removed their assignment Oct 3, 2022
@pjpratik
Contributor

I was able to reproduce this issue in TF Nightly 2.12.0-dev20221218. Please find the gist here. Thank you.

@synandi
Contributor
synandi commented Apr 20, 2023

I was able to replicate this issue in TF Nightly 2.13.0-dev20230419. Please find the gist here. Thank you.

@sushreebarsa
Contributor

Hi,

Thank you for opening this issue. Since this issue has been open for a long time, the code and debug information in it may no longer be relevant to the current state of the code base.

The TensorFlow team is constantly improving the framework by fixing bugs and adding new features. We suggest you try the latest TensorFlow version with the latest compatible hardware configuration, which could potentially resolve the issue. If you are still facing the issue, please create a new GitHub issue with your latest findings and all the debugging information that could help us investigate.

Please follow the release notes to stay up to date with the latest developments in the TensorFlow space.

Thank you!

@sushreebarsa sushreebarsa self-assigned this May 21, 2024
@sushreebarsa sushreebarsa added the stat:awaiting response label May 21, 2024
@drasmuss
Contributor Author

This bug is still present in TensorFlow 2.16.1. The code from #7397 (comment) is still valid, and reproduces the bug. I have edited that into the original post for clarity.

@google-ml-butler google-ml-butler bot removed the stat:awaiting response label May 21, 2024
@sushreebarsa sushreebarsa removed their assignment May 22, 2024