
nan gradient issue #42889

Closed
j7168908jx opened this issue Sep 2, 2020 · 14 comments
Assignees
Labels
comp:keras (Keras related issues), comp:ops (OPs related issues), stat:awaiting response (Status - Awaiting response from author), TF 2.3 (Issues related to TF 2.3), type:bug (Bug)

Comments

@j7168908jx

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): Yes
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Windows 10
  • TensorFlow installed from (source or binary): pip
  • TensorFlow version (use command below): 2.2.0
  • Python version: 3.6.8 (not using GPU)

Describe the current behavior
Training on a simple example, TensorFlow sometimes produces NaN gradients.

Describe the expected behavior
Training should proceed without producing NaN gradients.

Standalone code to reproduce the issue

import tensorflow as tf
import numpy as np
import time
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
# on cpu nan at 1098 iter
# Iter    0: loss 9.99180138e-01,  runtime:     0.11
# Iter    1: loss 9.98347044e-01,  runtime:     0.11
# Iter    2: loss 9.97434497e-01,  runtime:     0.12
# on gpu nan at 2212 iter
# Iter    0: loss 9.99180079e-01,  runtime:     2.46
# Iter    1: loss 9.98346925e-01,  runtime:     2.47
# Iter    2: loss 9.97434497e-01,  runtime:     2.48

np.random.seed(1)
tf.keras.backend.set_floatx('float32')

func = lambda x: x
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-1)

data = np.array(
    [[0., 0.07179281], [0., 0.44064897], [0., 0.7666122], [0., -0.655319],
     [0., -0.28546047], [0., 0.8460491], [0., 0.14823522], [0., -0.14381762],
     [0., 0.7200559], [0., -0.92189044], [0., 0.37300184], [0., -0.525946],
     [0., 0.07766213], [0., 0.370439], [0., 0.17311008], [0., 0.88918954],
     [0., -0.5910955], [0., -0.947578], [0., -0.7192261], [0., 0.5109261],
     [0., 0.85887444], [0., -0.75145805], [0., 0.89897853], [0., 0.23428982],
     [0., 0.5785587], [0., 0.0541162], [0., 0.97772217], [0., 0.24339144],
     [0., -0.72505057], [0., -0.39533487], [0., 0.6692513], [0., -0.7257285],
     [0., 0.93652314], [0., 0.17861107], [0., 0.38464522], [0., 0.38880032],
     [0., -0.73994285], [0., -0.9602397], [0., 0.07763347], [0., 0.6147826],
     [0., 0.68406177], [0., 0.39951673], [0., -0.17188802], [0., -0.10017573],
     [0., 0.7917724], [0., 0.35767105], [0., 0.7892133], [0., -0.62747955],
     [0., 0.7562349], [0., -0.16161098], [0., -0.77050805], [0., 0.8068038],
     [0., -0.37315163], [0., -0.3467102], [0., -0.70654285], [0., -0.8679997],
     [0., 0.5002886], [0., -0.7214473], [0., 0.7718842], [0., -0.5767438],
     [0., 0.8550172], [0., 0.4230495], [0., -0.7064882], [0., 0.11737966],
     [0., 0.326883], [0., -0.439112], [0., -0.99425936], [0., -0.94338703],
     [0., -0.8153228], [0., 0.8651909], [0., -0.96342343], [0., 0.9296801],
     [0., -0.50757784], [0., 0.24734442], [0., 0.80675906], [0., 0.38375422],
     [0., -0.7953311], [0., -0.4127717], [0., 0.39363632], [0., -0.30887854],
     [0., -0.8299116], [0., -0.603797], [0., -0.9452248], [0., -0.80330634],
     [0., 0.34093502], [0., -0.793548], [0., 0.6014891], [0., 0.7527783],
     [0., 0.38179383], [0., -0.9000931], [0., 0.4963313], [0., 0.45199597],
     [0., -0.9612661], [0., -0.30446827], [0., 0.9946457], [0., 0.14735897],
     [0., 0.24672022], [0., -0.20646505], [0., -0.20464632], [0., -0.1837264],
     [0., 0.8170703], [0., -0.15778475], [0., 0.5018849], [0., -0.8932749],
     [0., 0.10564396], [0., 0.91577905], [0., -0.01685368], [0., -0.42444932],
     [0., -0.30220333], [0., -0.46014422], [0., -0.99977124], [0., 0.06633057],
     [0., 0.15677923], [0., -0.46890667], [0., -0.36896873], [0., -0.6692916],
     [0., -0.17164145], [0., 0.756285], [0., -0.16595599], [0., 0.817191],
     [0., 0.5016242], [0., 0.3275893], [0., 0.50775236], [0., 0.02977822],
     [0., -0.10421295], [0., -0.9683575], [0., -0.6603392], [0., -0.1653904]],
    dtype=np.float32)
data_x = data[:, 0:1]
data_y = data[:, 1:2]


def loss_func(x, y):
    # Relative error: per-sample norm of the residual divided by the norm of the target.
    return tf.reduce_mean(tf.norm(func(x) - y, axis=1) / tf.norm(y, axis=1))


class MyNN(tf.keras.Model):

    def __init__(self):
        super().__init__()

        self.input_dims = [1, 1]
        self.func = func
        self.optimizer = optimizer

        self.net1 = tf.keras.layers.Dense(
            **{"units": 4, "activation": 'relu',
               "kernel_initializer": {
                   'class_name': 'glorot_uniform',
                   'config': {'seed': 1}}}
        )
        self.net2 = tf.keras.layers.Dense(
            **{"units": 1, "activation": None,
               "kernel_initializer": {
                   'class_name': 'glorot_uniform',
                   'config': {'seed': 1}}}
        )

    def train_one_step(self, x, y):
        with tf.GradientTape() as tape:
            x_pred = self(x, y)
            loss = loss_func(x_pred, y)
        grads = tape.gradient(loss, self.trainable_variables)

        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        return loss

    def train(self, start_time=time.time(), max_iter=3000):
        for it in range(max_iter):
            loss = self.train_one_step(data_x, data_y)

            print("Iter %4d: loss %14.8e,  runtime: %8.2f"
                  % (it, loss.numpy(), time.time() - start_time))

    def call(self, x, y):
        r = y - self.func(x)
        g = self.net2(self.net1(r)) * 2e-3
        return x + g


model = MyNN()
model.train()
@amahendrakar
Contributor

Was able to reproduce the issue with TF v2.3 and TF-nightly. Please find the gist here. Thanks!

@amahendrakar amahendrakar added comp:keras Keras related issues TF 2.3 Issues related to TF 2.3 type:support Support issues and removed type:bug Bug labels Sep 2, 2020
@bhack
Contributor
bhack commented Sep 2, 2020

@j7168908jx Can you try with this formulation:

import numpy as np
import tensorflow as tf
import time
np.random.seed(1)
tf.keras.backend.set_floatx('float32')

func = lambda x: x
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-1)

# `data`, `data_x`, and `data_y` are identical to the arrays in the original report above.

def loss(model, x, y, training):
  # training=training is needed only if there are layers with different
  # behavior during training versus inference (e.g. Dropout).
  y_ = model(x, y, training=training)

  return loss_func(y, y_)


def loss_func(x, y):
    return tf.reduce_mean(tf.norm(func(x) - y, axis=1) / tf.norm(y, axis=1))


class MyNN(tf.keras.Model):

    def __init__(self):
        super().__init__()

        self.input_dims = [1, 1]
        self.func = func
        self.optimizer = optimizer

        self.net1 = tf.keras.layers.Dense(
            **{"units": 4, "activation": 'relu',
               "kernel_initializer": {
                   'class_name': 'glorot_uniform',
                   'config': {'seed': 1}}}
        )
        self.net2 = tf.keras.layers.Dense(
            **{"units": 1, "activation": None,
               "kernel_initializer": {
                   'class_name': 'glorot_uniform',
                   'config': {'seed': 1}}}
        )

    def grad(self, inputs, targets):
        with tf.GradientTape() as tape:
            loss_value = loss(self, inputs, targets, training=True)
        return loss_value, tape.gradient(loss_value, model.trainable_variables)

    def train_one_step(self, x, y):
        loss_value, grads = self.grad(x, y)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        return loss_value


    def train(self, start_time=time.time(), max_iter=3000):
        for it in range(max_iter):
            loss = self.train_one_step(data_x, data_y)
            print("Iter %4d: loss %14.8e,  runtime: %8.2f" % (it, loss.numpy(), time.time() - start_time))

    def call(self, x, y):
        r = y - self.func(x)
        g = self.net2(self.net1(r)) * 2e-3
        return x + g


model = MyNN()
model.train()

@j7168908jx
Author

@bhack I tried this formulation and found that the loss value is different from what I had. Also, the NaN issue still remains:

Iter 0: loss 4.19011816e+03, runtime: 0.18
Iter 1: loss 8.31127930e+02, runtime: 0.19
Iter 2: loss 5.12066895e+02, runtime: 0.20
Iter 3: loss 3.87902283e+02, runtime: 0.21

Iter 7848: loss 5.08754015e-01, runtime: 107.28
Iter 7849: loss 5.08563876e-01, runtime: 107.30
Iter 7850: loss nan, runtime: 107.31
Iter 7851: loss nan, runtime: 107.32

@bhack
Contributor
bhack commented Sep 4, 2020

@j7168908jx Other than some small cosmetic refactoring, as you can see I just inverted the predicted and true order in loss_func(y, y_) to check whether there was a problem between the numerator and denominator in your loss.
I ran that code on Colab with your original example (max_iter=3000), but it now seems you have increased the number of iterations.
I think you should double-check the behavior of the gradient of your loss, since you are using tf.norm; see #12071.
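
The pointer to #12071 is the key hint here: the gradient of tf.norm at the zero vector is NaN. A minimal standalone check (a sketch added for illustration, not code from the thread, assuming TF 2.x eager mode):

import tensorflow as tf

# The gradient of ||z|| with respect to z is z / ||z||, which evaluates to
# 0/0 = NaN when z is exactly the zero vector -- i.e. when the prediction
# matches the target exactly.
z = tf.zeros([1, 1])
with tf.GradientTape() as tape:
    tape.watch(z)
    n = tf.norm(z, axis=1)
print(tape.gradient(n, z))  # -> [[nan]]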

@bhack
Contributor
bhack commented Sep 4, 2020

/cc @rmlarsen

@YingzhouLi

@bhack, the denominator in the loss is the norm of the true $y$ values, which are all bounded away from zero. Hence the denominator depends only on which samples land in the batch and is independent of any trainable variables. Unless something weird is happening in the backpropagation, the denominator should not be anywhere near zero.
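
A quick way to verify this claim (a sketch reusing the `data_y` array from the original report; not code from the thread):

import numpy as np

# The targets are 1-D, so the per-sample denominator tf.norm(y, axis=1) is just |y|.
# Printing its minimum over the dataset shows it is bounded away from zero.
print(np.abs(data_y).min())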

@bhack
Contributor
bhack commented Sep 7, 2020

Can you check the impact of the numerator norm?

@YingzhouLi

After removing the denominator, the nan appears even earlier. See gist.

@bhack
Contributor
bhack commented Sep 8, 2020

My previous comment was about the numerator.

@YingzhouLi
YingzhouLi commented Sep 8, 2020

I have implemented the norm with TF functions and also removed the sqrt from the implementation. The NaN issue is indeed caused by the tf.sqrt function; likely its argument is nearly zero on some batch (this last point is not verified). See gist.

loss function            nan issue
norm(.)                  nan at iter 618
square(norm(.))          no nan
sqrt(sum(square(.)))     nan at iter 618
sum(square(.))           no nan
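
This pattern matches the chain rule through the square root: d/ds sqrt(s) = 1/(2*sqrt(s)) is infinite at s = 0, and multiplying it by the zero inner gradient propagates NaN. A standalone check (a sketch, not from the thread, assuming TF 2.x eager mode):

import tensorflow as tf

x = tf.zeros([3], dtype=tf.float32)
with tf.GradientTape(persistent=True) as tape:
    tape.watch(x)
    with_sqrt = tf.sqrt(tf.reduce_sum(tf.square(x)))  # norm-style: sqrt(sum(square(.)))
    no_sqrt = tf.reduce_sum(tf.square(x))             # squared norm: sum(square(.))

print(tape.gradient(with_sqrt, x))  # [nan nan nan]
print(tape.gradient(no_sqrt, x))    # [0. 0. 0.]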

@gowthamkpr gowthamkpr added type:bug Bug comp:ops OPs related issues and removed type:support Support issues labels Sep 18, 2020
@gowthamkpr gowthamkpr assigned rmlarsen and unassigned gowthamkpr Sep 18, 2020
@gowthamkpr gowthamkpr added the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Sep 18, 2020
@sachinprasadhs
Contributor

Was able to reproduce your issue in TensorFlow 2.5, please find the gist here. Thanks!

@mohantym
Contributor

Hi @j7168908jx! I was able to resolve this issue by squaring the denominator inside the tf.reduce_mean call, as suggested in this comment. Attaching a gist for reference. Can we move this issue to closed status now?

def loss_func(x, y):
    return tf.reduce_mean(tf.norm(func(x) - y, axis=1) / tf.math.square(tf.norm(y, axis=1)))
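
Note that squaring the denominator changes which loss is being optimized. An alternative that keeps the original relative-error loss would be to add a small constant under the square root so the gradient stays finite when the residual reaches zero (a hedged sketch, not suggested in the thread; `safe_norm` and `eps` are hypothetical names/values):

def safe_norm(z, axis=None, eps=1e-12):
    # eps keeps the gradient z / sqrt(sum(z**2) + eps) finite at z == 0.
    return tf.sqrt(tf.reduce_sum(tf.square(z), axis=axis) + eps)

def loss_func(x, y):
    # Same relative-error loss as in the original report, but NaN-safe.
    return tf.reduce_mean(safe_norm(func(x) - y, axis=1) / safe_norm(y, axis=1))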

@mohantym mohantym self-assigned this Mar 18, 2022
@mohantym mohantym added stat:awaiting response Status - Awaiting response from author and removed stat:awaiting tensorflower Status - Awaiting response from tensorflower labels Mar 18, 2022
@j7168908jx
Author

Thanks for all the comments above! I'd like to close this since squaring does resolve the problem for now.

@google-ml-butler

Are you satisfied with the resolution of your issue?
