Numerical instability of gradient calculation of tf.norm (nan at 0, inf for small values) #12071
This is caused by the square root in the definition of tf.norm, i.e. you are taking the gradient of sqrt(x^2). The gradient of sqrt approaches infinity while the gradient of x^2 approaches 0, so computing them separately and then multiplying is a problem. @goodfeli -- do you remember if Theano uses some standard stabilizing transformation for this kind of case?
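[Editor's note] The failure mode described above is easy to reproduce outside TensorFlow. A minimal NumPy sketch (illustrative only, not the TF implementation) of the chain rule applied literally to sqrt(x^2):

```python
import numpy as np

def naive_norm_grad(x):
    # Chain rule applied literally to sqrt(x^2):
    # d/dx sqrt(x^2) = 1/(2*sqrt(x^2)) * 2x = x / sqrt(x^2)
    x = np.asarray(x, dtype=np.float64)
    with np.errstate(divide="ignore", invalid="ignore"):
        return x / np.sqrt(x * x)

print(naive_norm_grad(0.0))     # 0/0 -> nan
print(naive_norm_grad(1e-200))  # x*x underflows to 0, so x/0 -> inf
```

The true derivative of |x| is sign(x) = 1 for positive x, but the separately computed factors give nan at 0 and inf once x*x underflows.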
Theano gives NaN as the gradient of the norm of a vector with zero norm:
Theano does give 0 as the gradient of the norm of a scalar. I think for this it is probably using a patternsub to turn sqrt(square(x)) into abs(x). Note that in Theano's conventions, the derivative of abs(x) at 0 is treated as 0 rather than undefined.
Thinking philosophically, the general problem is that the computational graph ends up with expressions that are undefined at the limit. You have to do some algebraic massaging to get a numerically defined result. In your particular case, you could replace the automatic gradient with a stable version:
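[Editor's note] The code block from this comment did not survive the export. As a rough stand-in (the function name and eps guard are my own, not the original snippet), a fused gradient of the Euclidean norm in plain NumPy looks like this:

```python
import numpy as np

def stable_norm_grad(x, dy=1.0, eps=1e-12):
    # Fused gradient of the Euclidean norm: d||x||/dx = x / ||x||.
    # Computing it as one expression (with an eps guard in the denominator)
    # avoids the separate sqrt'(u) * u' factors that produce 0/0 at x = 0.
    x = np.asarray(x, dtype=np.float64)
    n = np.sqrt(np.sum(x * x))
    return dy * x / max(n, eps)

stable_norm_grad(np.array([3.0, 4.0]))  # -> [0.6, 0.8]
stable_norm_grad(np.zeros(3))           # -> [0., 0., 0.], no nan
```

In TensorFlow terms, the thread's actual mechanism is a gradient override (later comments use tf.custom_gradient); the point here is only the fused formula.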
First, @yaroslavvb, thanks a lot for the code for the stable version of the gradient (I did not know that it's possible to override the gradient calculation). I agree that the corner case 0 is not really defined, but the same is true for the gradient of a ReLU at 0. I did not include this in my original report, but the problem is also present for non-scalars; the same nan/inf gradients appear there.
Yes, I agree that this should be fixed; the question is how. The bug with tf.select-gradient-caused NaNs linked above is still open a year later, and it's a similar problem. A Python-specific solution would be to replace norm with a fused version like the one above, with a corresponding gradient. Fused versions are not really preferred, but are done if it's important enough. A more general solution is to have a numerically stabilizing transformation which simplifies expressions that lead to "0/0" or "infinity*0" so they behave more like their limits. @benoitsteiner @petewarden -- do you know if there's anything on numerically stabilizing graph transformations on the roadmap?
+1 -- I've been hitting this several times already.
@yaroslavvb Hi, I don't understand why `dy*(x/tf.norm(x))` leads to a stable gradient. Could you please explain it more? Thank you.
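[Editor's note] For other readers: for x != 0 the derivative of ||x|| = sqrt(sum(x_i^2)) is x_i / ||x||, so `dy*(x/tf.norm(x))` evaluates that closed form directly instead of multiplying a blowing-up sqrt' factor by a vanishing square' factor. A quick NumPy check against a numerical derivative:

```python
import numpy as np

x = np.array([3.0, 4.0])
analytic = x / np.linalg.norm(x)  # x / ||x||, the closed-form gradient

# Central finite differences as an independent check
h = 1e-6
numeric = np.array([
    (np.linalg.norm(x + h * e) - np.linalg.norm(x - h * e)) / (2 * h)
    for e in np.eye(2)
])
print(analytic)  # [0.6 0.8]
print(numeric)   # agrees with the analytic gradient
```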
+1 I'm also facing this issue while computing the Euclidean norm for the paper "A Structured Self-attentive Sentence Embedding". Edit: PyTorch applied a fix for it; maybe TF can do the same: https://github.com/pytorch/pytorch/pull/2775/files
Any update on this?
Bump for my honors thesis XD
Assigning gradient issue to @girving
I posted a stable implementation above which you can substitute for the TF default version. I don't know if this works for an external contributor; it depends on whether there are internal tests that may be broken by changing the numerics of tf.norm.
Marking as contributions welcome. I gave an example solution above; someone wanting to be a part of TF could work the solution into TensorFlow.
Hi @yaroslavvb ,
It will give output:
Using your suggestion:
I still get:
So here's my simple modification that seems to work (just by adding a small epsilon to the gradient's denominator):
This gives:
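[Editor's note] The modified snippet itself did not survive the export. The idea described above, sketched in plain NumPy (function name and eps value are my own):

```python
import numpy as np

def norm_grad_with_eps(x, dy=1.0, eps=1e-16):
    # Small epsilon added to the gradient's denominator, as described above:
    # dy * x / (||x|| + eps) is finite everywhere, including at x = 0.
    x = np.asarray(x, dtype=np.float64)
    return dy * x / (np.sqrt(np.sum(x * x)) + eps)

norm_grad_with_eps(np.zeros(2))           # -> [0., 0.] instead of nan
norm_grad_with_eps(np.array([3.0, 4.0]))  # -> ~[0.6, 0.8]
```

Note the trade-off: for very small but nonzero x the result is biased toward 0 instead of being a unit vector, but it never produces inf or nan.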
@yaroslavvb Thanks for the snippet! Unfortunately, it does not work with eager execution enabled. The following works with eager execution enabled:
Thanks, @gsutanto, for your solution. I wanted to ask what the correct way is to force-calculate gradients when there is a complicated interaction of variables, such as when we're using gradient descent to optimize a loss function with the norm as a penalty term. Do I have to force the gradient of the entire loss function, or should only the gradient of the norm be computed separately? Drawing from your code, what I'm saying is something like the case below: ...

X = tf.placeholder(tf.float32, shape=(4, 1))
Y = tf.placeholder(tf.float32, shape=(1, 1))
W = tf.Variable(tf.zeros(shape=[1, 4]))
Z = norm(X)  # norm() with the custom stable gradient defined above
var_grad = tf.gradients(Z, [X])
J = tf.reduce_mean((Y - tf.matmul(W, X))) + (.5 * Z)
optimizer = tf.train.GradientDescentOptimizer(1).minimize(J)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    X_ = np.array([[0],
                   [0],
                   [0],
                   [0]], dtype=np.float32)
    Y_ = np.array([[1]], dtype=np.float32)
    sess.run([optimizer, J, var_grad], feed_dict={X: X_, Y: Y_})
Hi @Dust0x, if you work on a batch of data, it would be something like:
(Say X_ were NxM = 4x3: each row is a feature vector of size 3 that you would like to compute norm() on, and there are N=4 samples. There might be a better way to do this; if anyone knows, please post your code here. As for TensorFlow, I believe it will automatically use the newly defined norm gradient in any cost function you use it in; please let me know if my understanding is incorrect.)
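[Editor's note] The batched snippet is also missing from the export. A NumPy sketch of the idea for an N x M batch where each row gets its own norm (the eps guard and names are my own assumptions):

```python
import numpy as np

def batch_norm_grad(X, eps=1e-12):
    # Gradient of sum_i ||X[i]|| w.r.t. X: each row divided by its own
    # norm, guarded so all-zero rows get a zero gradient instead of nan.
    X = np.asarray(X, dtype=np.float64)
    norms = np.sqrt(np.sum(X * X, axis=1, keepdims=True))  # shape (N, 1)
    return X / np.maximum(norms, eps)

X = np.array([[3.0, 4.0, 0.0],
              [0.0, 0.0, 0.0]])
batch_norm_grad(X)  # rows: [0.6, 0.8, 0.0] and [0.0, 0.0, 0.0]
```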
Thanks again, @gsutanto. Luckily for me, the input to the norm only needs to be fed once in my case, so I don't (yet) have to worry about batch computations. The way I am currently solving this is by simply calling the forced gradient of norm in the same session as the optimizer, and if what you say is correct (nothing seems to have broken so far, so it must be), this should work just fine.
Hi, I am hitting this too when trying to use the pairwise distance norm of a batch of vectors; here is the SO question. Am I getting this issue because of this bug?
I am getting the same bug on a distance norm of a batch of vectors; we have very similar situations. Edit: although it is certainly not the same thing, I found that taking the absolute value, i.e. |x_1| + |x_2| + |x_3| + ..., does not suffer from the gradient instabilities.
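[Editor's note] To illustrate the comment above (my own check, not from the thread): the subgradient of sum(|x_i|) is sign(x), which is finite everywhere, so the L1 norm never produces inf/nan gradients even for zero or denormal entries:

```python
import numpy as np

x = np.array([-2.0, 0.0, 1e-300, 3.0])
grad_l1 = np.sign(x)  # elementwise subgradient of sum(|x_i|); 0 at the origin
print(grad_l1)        # finite even for tiny/zero entries
```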
Hi @gsutanto, @gokul-uf, I've tried to implement your code, like this: `@tf.custom_gradient def custom_norm1(X): ...` I've added the
I can tell it is because now I have to multiply that tensor by
@macarena1807, I believe you can apply
Prevent losses from diverging as per tensorflow/tensorflow#12071.
Thank you for the solution, @yaroslavvb, @gsutanto, and @gokul-uf. I needed the L2 norm for a personal project and was having a lot of issues. During the first epochs the network was stable and even able to converge; then, suddenly, the loss exploded to infinity, without any kind of warning or symptoms. I tested every technique I knew to try to solve the problem without really knowing what was going on, including gradient clipping, regularization, different architectures, activation functions and initialization techniques, residual connections, and different optimizers and learning rates, but nothing really worked. Keeping the same structure, I then changed the norm to L1 and everything worked fine; however, L2 was a necessary requirement. This solution completely solved the exploding-gradient issue while using the L2 norm, stabilized the training, and provided the best results so far. Imports:
Custom Keras L2 Norm Layer with custom gradient to prevent numerical instability:
Example of how to apply layer in a
This issue is stale because it has been open for 180 days with no activity. It will be closed if no further activity occurs. Thank you.
I wonder if this issue still exists with tf.norm? It seems that I should use tf.nn.l2_normalize instead, since I see hanxiao has made a fix for it?
Hi, a common and effective approach is to add a small epsilon value (ε) to the squared sum before taking the square root in the norm calculation. This prevents division by zero and improves stability for very small values. Here's how you can modify your code:
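[Editor's note] The commenter's modified code did not survive the export. The approach described, sketched in NumPy (function name and eps value are my own):

```python
import numpy as np

def safe_norm(x, eps=1e-12):
    # Epsilon inside the square root: the function is smooth everywhere,
    # and its gradient x / sqrt(sum(x^2) + eps) is defined even at x = 0.
    x = np.asarray(x, dtype=np.float64)
    return np.sqrt(np.sum(np.square(x)) + eps)

safe_norm(np.zeros(3))           # -> 1e-6 (= sqrt(eps)), not exactly 0
safe_norm(np.array([3.0, 4.0]))  # -> ~5.0
```

The cost of this variant is a small bias: the "norm" of the zero vector is sqrt(eps) rather than 0, which is usually acceptable inside a loss term.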
Since this issue has been open for a long time, the code/debug information for this issue may not be relevant with the current state of the code base. The TensorFlow team is constantly improving the framework by fixing bugs and adding new features. We suggest you try the latest TensorFlow version with the latest compatible hardware configuration, which could potentially resolve the issue. If you are still facing the issue, please create a new GitHub issue with your latest findings, with all the debugging information that could help us investigate. Please follow the release notes to stay up to date with the latest developments in the TensorFlow space. Thank you!
Hi, if the standard implementation is still not working, it will still cause problems for users of the library.
Hi, if the standard implementation is still not working, it will still cause problems for users of the library. (I accidentally closed it in my comment before.)
System information
Describe the problem
`nan` is calculated for the gradient of `tf.norm` at zero values. For extremely small values, `inf` is calculated. Note that the exact result should be 1 in all cases above. Above is a minimal example to reproduce it. The problem occurred in a real-world scenario, when implementing a custom loss function (the entropy in https://arxiv.org/abs/1611.01449) and two embeddings were too close to each other (distance practically 0).
Source code / logs
See above
Output of logfile