Wrong gradient in forward mode auto differentiation for tf.random.stateless_parameterized_truncated_normal #56842

Open
eeDigitalSeal opened this issue Jul 20, 2022 · 5 comments
Labels
comp:ops, stat:awaiting tensorflower, TF 2.9, type:bug

eeDigitalSeal commented Jul 20, 2022

Issue Type

Bug

Source

binary

Tensorflow Version

tf 2.9

Custom Code

Yes

OS Platform and Distribution

Linux Ubuntu 20.04

Mobile device

No response

Python version

3.9

Bazel version

No response

GCC/Compiler version

No response

CUDA/cuDNN version

No response

GPU model and memory

No response

Current Behaviour?

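For the same input, the Jacobian elements computed in forward mode (tf.autodiff.ForwardAccumulator) and in backward mode (tf.GradientTape.jacobian) differ: backward mode returns zeros in the second row, while forward mode returns NaN there.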

Standalone code to reproduce the issue

```python
import tensorflow as tf

shape = [2, 3]
seed = [7, 17]
means = 13.0
stddevs = tf.constant([[0.8059583, 0.09676647, 0.08382106],
                       [0.8149866, 0.44204712, 0.5636599]], dtype=tf.float32)
minvals = [-1.0, -2.0, -1000.0]
maxvals = [[10000.0], [1.0]]

# Backward mode: full Jacobian d(res)/d(stddevs).
with tf.GradientTape(persistent=True) as g:
  g.watch(stddevs)
  tf.random.set_seed(42)  # the global seed does not affect stateless ops
  res_backward = tf.random.stateless_parameterized_truncated_normal(
      shape, seed, means=means, stddevs=stddevs,
      minvals=minvals, maxvals=maxvals)
# shape=(2,3,2,3)
jacobian = g.jacobian(res_backward, stddevs)
# d(res[0, 1]) / d(stddevs)
print(jacobian[0][1])

# Forward mode: JVP with a one-hot tangent on stddevs[0, 1].
tangents = tf.constant([[0., 1., 0.],
                        [0., 0., 0.]], shape=(2, 3), dtype=tf.float32)
with tf.autodiff.ForwardAccumulator(stddevs, tangents) as acc:
  res_forward = tf.random.stateless_parameterized_truncated_normal(
      shape, seed, means=means, stddevs=stddevs,
      minvals=minvals, maxvals=maxvals)
print(acc.jvp(res_forward))
```

Relevant log output

```
tf.Tensor(
[[ 0.         0.8470439  0.       ]
 [-0.        -0.        -0.       ]], shape=(2, 3), dtype=float32)
tf.Tensor(
[[0.        0.8470439 0.       ]
 [      nan       nan       nan]], shape=(2, 3), dtype=float32)
```
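The reproduction compares `jacobian[0][1]`, which is one row of the (2, 3, 2, 3) Jacobian (d res[0, 1] / d stddevs), with the JVP for a one-hot tangent on `stddevs[0, 1]`, which is one column (d res / d stddevs[0, 1]). Assuming, as the backward-mode output suggests, that each output element depends only on the matching element of `stddevs`, the Jacobian is diagonal and the row and column coincide. A minimal NumPy sketch of that bookkeeping; the per-element derivative values are hypothetical:

```python
import numpy as np

# Hypothetical per-element derivatives d(res[i, j]) / d(stddevs[i, j]).
diag = np.arange(1.0, 7.0).reshape(2, 3)

# Full Jacobian J[i, j, k, l] = d(res[i, j]) / d(stddevs[k, l]).
# For an elementwise dependence only the entries with (i, j) == (k, l) are nonzero.
jac = np.zeros((2, 3, 2, 3))
for i in range(2):
    for j in range(3):
        jac[i, j, i, j] = diag[i, j]

# Analogue of GradientTape.jacobian(...)[0][1]: one row, d(res[0, 1]) / d(stddevs).
row = jac[0, 1]

# Analogue of a ForwardAccumulator JVP with a one-hot tangent: one column,
# d(res) / d(stddevs[0, 1]), obtained by contracting J with the tangent.
tangent = np.zeros((2, 3))
tangent[0, 1] = 1.0
jvp = np.tensordot(jac, tangent, axes=([2, 3], [0, 1]))

print(np.array_equal(row, jvp))  # True, because the Jacobian is diagonal
```

If the op mixed elements (a non-diagonal Jacobian), a row and a column would in general differ, and the forward/backward comparison would need matching one-hot vectors on both sides.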
@tilakrayal
Contributor

@eeDigitalSeal,
I tried an alternative approach for stddevs, and the code ran without issues and gave the expected result. Please find the gist here.

> The values follow a normal distribution with specified `mean` and `standard deviation`, except that values whose magnitude is more than 2 standard deviations from the mean are dropped and re-picked.

> The standard deviation of the truncated normal distribution. This must broadcast with means, minvals and maxvals, and the broadcasted shape must be dominated by shape.

Thank you!
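The broadcasting requirement quoted above can be checked for the reproduction's parameters with NumPy's broadcasting rules, assumed here to behave like TensorFlow's for this purpose. "Dominated by shape" is read here as "broadcasting the parameter shape against `shape` yields `shape`"; that reading is our assumption. The sketch also prints the per-element truncation interval, which shows that every maxval in the second row is 1.0, far below the mean of 13.0:

```python
import numpy as np

# Parameter values copied from the reproduction in this issue.
shape = (2, 3)
means = np.float32(13.0)                                        # shape ()
stddevs = np.array([[0.8059583, 0.09676647, 0.08382106],
                    [0.8149866, 0.44204712, 0.5636599]], dtype=np.float32)  # (2, 3)
minvals = np.array([-1.0, -2.0, -1000.0], dtype=np.float32)    # (3,)
maxvals = np.array([[10000.0], [1.0]], dtype=np.float32)       # (2, 1)

# Common broadcast shape of all four parameters.
params_shape = np.broadcast_shapes(means.shape, stddevs.shape,
                                   minvals.shape, maxvals.shape)
# Our reading of "dominated by shape": broadcasting against `shape` changes nothing.
dominated = np.broadcast_shapes(params_shape, shape) == shape
print(params_shape, dominated)  # (2, 3) True

# Per-element truncation interval [lo, hi] after broadcasting.
lo, hi = np.broadcast_arrays(minvals, maxvals)
print(lo)
print(hi)
```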

tilakrayal added the TF 2.9, comp:ops, and stat:awaiting response labels Jul 21, 2022
@eeDigitalSeal
Author

Thanks @tilakrayal for your reply. I still don't understand why the code I provided gives different gradients in forward mode and backward mode. I see that you changed stddevs to tf.math.exp(tf.constant(...)). Do you mean that the original values for stddevs are invalid?

```python
stddevs = tf.constant([[0.8059583, 0.09676647, 0.08382106],
                       [0.8149866, 0.44204712, 0.5636599]], dtype=tf.float32)
```

But I think stateless_parameterized_truncated_normal should accept any positive stddevs, including this tensor.

google-ml-butler bot removed the stat:awaiting response label Jul 23, 2022
@tilakrayal
Contributor

@gowthamkpr,
I was able to reproduce the issue on TensorFlow v2.8, v2.9, and nightly. Please find the gist here.

tilakrayal assigned gowthamkpr and unassigned tilakrayal Jul 26, 2022
gowthamkpr added the stat:awaiting tensorflower label Aug 4, 2022
@cantonios
Contributor

@eeDigitalSeal The issue here is that NaNs are actually encountered in the forward pass because of your choice of means and minvals/maxvals. You'll see NaNs in the backward gradients if you take the gradient w.r.t. minvals, for example. Once we hit a NaN, we can't really recover. The backward gradient for stddevs happens to avoid this.

You'll get a correct answer if you increase the precision to tf.float64.
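A small NumPy illustration of why float32 can be the problem with these parameters. It is not TensorFlow's kernel code, and the assumption that the op's gradient involves quantities of this kind is ours. In the second row, maxvals is 1.0 while means is 13.0, so the standardized upper bounds z = (maxval - mean) / stddev are roughly -14.7, -27.1 and -21.3. The unnormalized standard-normal density exp(-z**2 / 2) at those points underflows to zero in float32 but remains representable in float64; any formula that divides by such a quantity would then produce 0/0 = NaN in float32:

```python
import numpy as np

# Parameter values copied from the reproduction in this issue.
MEANS = 13.0
STDDEVS = [[0.8059583, 0.09676647, 0.08382106],
           [0.8149866, 0.44204712, 0.5636599]]
MAXVALS = [[10000.0], [1.0]]  # broadcasts across the 3 columns


def upper_bound_density(dtype):
    """exp(-z**2 / 2) at z = (maxval - mean) / stddev, computed in `dtype`, for row 1."""
    s = np.asarray(STDDEVS, dtype=dtype)
    b = np.asarray(MAXVALS, dtype=dtype)
    m = dtype(MEANS)
    z = (b - m) / s  # shape (2, 3); row 1 is roughly [-14.7, -27.1, -21.3]
    return np.exp(-0.5 * z * z)[1]


d32 = upper_bound_density(np.float32)
d64 = upper_bound_density(np.float64)
print("float32:", d32)  # all zeros: the values underflow
print("float64:", d64)  # tiny but nonzero
```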

@eeDigitalSeal
Author

@cantonios Thanks for your reply. But if I change the input dtype to tf.float64, this API throws the following error. It seems that it only accepts float32 tensors as input.

tensorflow.python.framework.errors_impl.InvalidArgumentError: cannot compute StatelessParameterizedTruncatedNormal as input #3(zero-based) was expected to be a float tensor but is a double tensor [Op:StatelessParameterizedTruncatedNormal]

Also, if the backward gradient hits a NaN, shouldn't it return that NaN rather than handling it silently inside the framework?
