
Low performance when using persistent mode GradientTape with LSTM layers #35928

Closed
zhuzhi-fairy opened this issue Jan 16, 2020 · 11 comments
Labels: comp:ops (OPs related issues), TF 2.1 (tracking issues in the 2.1 release), type:performance (Performance Issue)

Comments

zhuzhi-fairy commented Jan 16, 2020


System information

  • Have I written custom code: Yes
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Ubuntu 16.04
  • TensorFlow installed from (source or binary): Conda
  • TensorFlow version (use command below): 2.0.0 and 2.1.0
  • Python version: 3.7.5
  • CUDA/cuDNN version: 10.0.130 / 7.6.4
  • GPU model and memory: GTX 980 Ti, GTX 2080 Ti

Describe the current behavior
Performance is very low when using a persistent tf.GradientTape, or when creating multiple GradientTape objects in a single with block.
This only happens when the model includes an LSTM layer.

Code to reproduce the issue
Here is a sample code to reproduce the problem.

import timeit
import numpy as np
import tensorflow as tf

model0 = tf.keras.models.Sequential(
    [tf.keras.layers.LSTM(128, input_shape=(300, 40))]
)
model1 = tf.keras.models.Sequential(
    [tf.keras.layers.Dense(1, activation='sigmoid', input_shape=(128,))]
)
loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)

@tf.function
def train_step_0():
  with tf.GradientTape() as tape0, tf.GradientTape() as tape1:
    f = model0(tf.random.normal([256, 300, 40]))
    y_p0 = model1(f)
    y_p1 = model1(tf.random.normal((256, 128)))
    loss0 = loss_object(tf.zeros_like(y_p0), y_p0)
    loss1 = loss_object(tf.ones_like(y_p1), y_p1)
  grad0 = tape0.gradient(loss0, model0.trainable_variables)
  grad1 = tape1.gradient(loss1, model1.trainable_variables)

train_step_0()
print(timeit.timeit('train_step_0()', globals=globals(), number=100))

In this case, GPU utilization was only about 30% and the run took 37 seconds.
The following messages were printed:

2020-01-16 17:16:22.519978: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:502] function_optimizer failed: Invalid argument: Input 0 of node sequential/lstm/zeros_like was passed int32 from sequential/lstm/StatefulPartitionedCall:9 incompatible with expected variant.
2020-01-16 17:16:22.667364: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:502] function_optimizer failed: Invalid argument: Input 0 of node sequential/lstm/zeros_like was passed int32 from sequential/lstm/StatefulPartitionedCall:9 incompatible with expected variant.
2020-01-16 17:16:22.739393: W tensorflow/core/common_runtime/process_function_library_runtime.cc:675] Ignoring multi-device function optimization failure: Invalid argument: Input 0 of node sequential/lstm/zeros_like was passed int32 from sequential/lstm/StatefulPartitionedCall:9 incompatible with expected variant.
@tf.function
def train_step_1():
  with tf.GradientTape(persistent=True) as tape:
    f = model0(tf.random.normal([256, 300, 40]))
    y_p0 = model1(f)
    y_p1 = model1(tf.random.normal((256, 128)))
    loss0 = loss_object(tf.zeros_like(y_p0), y_p0)
    loss1 = loss_object(tf.ones_like(y_p1), y_p1)
  grad0 = tape.gradient(loss0, model0.trainable_variables)
  grad1 = tape.gradient(loss1, model1.trainable_variables)

train_step_1()
print(timeit.timeit('train_step_1()', globals=globals(), number=100))

In this case, the performance was similar to the previous one.

@tf.function
def train_step_2():
  with tf.GradientTape() as tape0:
    f = model0(tf.random.normal([256, 300, 40]))
    y_p0 = model1(f)
    loss0 = loss_object(tf.zeros_like(y_p0), y_p0)
  with tf.GradientTape() as tape1:
    y_p1 = model1(tf.random.normal((256, 128)))
    loss1 = loss_object(tf.ones_like(y_p1), y_p1)
  grad0 = tape0.gradient(loss0, model0.trainable_variables)
  grad1 = tape1.gradient(loss1, model1.trainable_variables)

train_step_2()
print(timeit.timeit('train_step_2()', globals=globals(), number=100))

In this case, GPU utilization was almost 100% and the run took only 4 seconds.

In my opinion, train_step_0, train_step_1, and train_step_2 should have similar performance.
I am wondering why the GPU utilization and run times differ so much.
The strangest thing is that this only happens when the model includes an LSTM layer:
if the LSTM layer is replaced with a Conv or Dense layer, the run times are all the same.
Here is a colaboratory page to reproduce this problem.
https://colab.research.google.com/drive/1sluVFuW1yYtH0Ye4reoOEmLUHYHDSGn7
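The LSTM-only nature of the slowdown can be checked with a control experiment (my own sketch, not part of the original report): swapping the LSTM for a Dense layer and timing the two tape arrangements. With Dense layers, the shared-tape and split-tape variants are expected to run in comparable time.

```python
import timeit
import tensorflow as tf

# Dense stand-in for the LSTM: same 128-unit output, flat 40-dim input.
model0 = tf.keras.models.Sequential([tf.keras.layers.Dense(128)])
model1 = tf.keras.models.Sequential([tf.keras.layers.Dense(1, activation='sigmoid')])
loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)

@tf.function
def shared_tapes_step():
    # Same structure as train_step_0: two tapes opened in one `with` block.
    with tf.GradientTape() as tape0, tf.GradientTape() as tape1:
        y_p0 = model1(model0(tf.random.normal([256, 40])))
        y_p1 = model1(tf.random.normal([256, 128]))
        loss0 = loss_object(tf.zeros_like(y_p0), y_p0)
        loss1 = loss_object(tf.ones_like(y_p1), y_p1)
    grad0 = tape0.gradient(loss0, model0.trainable_variables)
    grad1 = tape1.gradient(loss1, model1.trainable_variables)
    return grad0, grad1

@tf.function
def split_tapes_step():
    # Same structure as train_step_2: one tape per `with` block.
    with tf.GradientTape() as tape0:
        y_p0 = model1(model0(tf.random.normal([256, 40])))
        loss0 = loss_object(tf.zeros_like(y_p0), y_p0)
    with tf.GradientTape() as tape1:
        y_p1 = model1(tf.random.normal([256, 128]))
        loss1 = loss_object(tf.ones_like(y_p1), y_p1)
    grad0 = tape0.gradient(loss0, model0.trainable_variables)
    grad1 = tape1.gradient(loss1, model1.trainable_variables)
    return grad0, grad1

shared_tapes_step()  # trace once before timing
split_tapes_step()
t_shared = timeit.timeit(lambda: shared_tapes_step(), number=10)
t_split = timeit.timeit(lambda: split_tapes_step(), number=10)
print(t_shared, t_split)  # expected to be comparable with Dense layers
```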

@ravikyram ravikyram self-assigned this Jan 17, 2020
@ravikyram ravikyram added comp:ops OPs related issues TF 2.1 for tracking issues in 2.1 release type:bug Bug labels Jan 17, 2020
@ravikyram ravikyram assigned ymodak and unassigned ravikyram Jan 17, 2020
@ravikyram ravikyram added type:performance Performance Issue and removed type:bug Bug labels Jan 17, 2020
@allenlavoie (Member)

There's some thought, or at least hope, that 3a03164 fixes the issue. Would you mind waiting until that's in a nightly (presumably tomorrow) and trying again?

If that fixes it then the performance issue was that this dtype mismatch meant that Grappler wasn't choosing the CuDNN LSTM implementation, and presumably the alternative was much slower.
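For context on that hypothesis (taken from the Keras LSTM documentation, not this thread): the Keras LSTM layer only dispatches to the fused cuDNN kernel when its configuration matches the cuDNN constraints (default tanh/sigmoid activations, recurrent_dropout=0, unroll=False, use_bias=True); any deviation silently falls back to the generic, much slower implementation. A minimal sketch of the two configurations:

```python
import tensorflow as tf

# cuDNN-eligible configuration: all defaults satisfy the documented requirements.
fast_lstm = tf.keras.layers.LSTM(128)

# A non-default activation (among other deviations) disables the cuDNN path
# and forces the generic implementation; the outputs are the same shape,
# but on a GPU the speed can differ dramatically.
slow_lstm = tf.keras.layers.LSTM(128, activation='relu')

x = tf.random.normal([8, 30, 40])
print(fast_lstm(x).shape, slow_lstm(x).shape)  # both (8, 128)
```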

@Andreas5739738

I can trigger the same warning messages and performance degradation by removing the @tf.function decorator from the GradientTape (decorating only its body). Example with TF 2.0.0:

import timeit
import numpy as np
import tensorflow as tf

model0 = tf.keras.models.Sequential(
    [tf.keras.layers.LSTM(128, input_shape=(300, 40))]
)
model1 = tf.keras.models.Sequential(
    [tf.keras.layers.Dense(1, activation='sigmoid', input_shape=(128,))]
)
loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)


@tf.function
def body0():
    f = model0(tf.random.normal([256, 300, 40]))
    y_p0 = model1(f)
    loss0 = loss_object(tf.zeros_like(y_p0), y_p0)
    return loss0

def train_step_3():
  with tf.GradientTape() as tape0:
    loss0 = body0()
  grad0 = tape0.gradient(loss0, model0.trainable_variables)


train_step_3()
print(timeit.timeit('train_step_3()', globals=globals(), number=100))

Output:

2020-03-11 15:39:48.084997: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:502] function_optimizer failed: Invalid argument: Input 0 of node sequential_lstm_statefulpartitionedcall_28_RetVal was passed int32 from sequential/lstm/StatefulPartitionedCall:29 incompatible with expected variant.
2020-03-11 15:39:48.138852: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:502] function_optimizer failed: Invalid argument: Input 0 of node sequential_lstm_statefulpartitionedcall_8_RetVal was passed int32 from sequential/lstm/StatefulPartitionedCall:9 incompatible with expected variant.
2020-03-11 15:39:48.168850: W tensorflow/core/common_runtime/process_function_library_runtime.cc:675] Ignoring multi-device function optimization failure: Invalid argument: Input 0 of node sequential_lstm_statefulpartitionedcall_8_RetVal was passed int32 from sequential/lstm/StatefulPartitionedCall:9 incompatible with expected variant.
2020-03-11 15:39:48.726501: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:502] implementation_selector failed: Invalid argument: Invalid format of input node name:  Expected: {forward_node_name}:{index}
2020-03-11 15:39:48.795141: W tensorflow/core/grappler/optimizers/implementation_selector.cc:310] Skipping optimization due to error while loading function libraries: Invalid argument: Functions '__inference___backward_standard_lstm_1162_1633' and '__inference___backward_standard_lstm_1162_1633_specialized_for_gradients_sequential_lstm_StatefulPartitionedCall_grad_StatefulPartitionedCall_at___inference___backward_body0_1074_1694' both implement 'lstm_bcb537ae-8579-4358-b3be-f7945b56523d' but their signatures do not match.
40.24392790393904

After adding the @tf.function decorator outside the GradientTape, I no longer get the warnings and it is faster:

import timeit
import numpy as np
import tensorflow as tf

model0 = tf.keras.models.Sequential(
    [tf.keras.layers.LSTM(128, input_shape=(300, 40))]
)
model1 = tf.keras.models.Sequential(
    [tf.keras.layers.Dense(1, activation='sigmoid', input_shape=(128,))]
)
loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)


@tf.function
def body0():
    f = model0(tf.random.normal([256, 300, 40]))
    y_p0 = model1(f)
    loss0 = loss_object(tf.zeros_like(y_p0), y_p0)
    return loss0

@tf.function
def train_step_3():
  with tf.GradientTape() as tape0:
    loss0 = body0()
  grad0 = tape0.gradient(loss0, model0.trainable_variables)


train_step_3()
print(timeit.timeit('train_step_3()', globals=globals(), number=100))

Output:

2020-03-11 15:41:23.242231: W tensorflow/core/grappler/optimizers/implementation_selector.cc:310] Skipping optimization due to error while loading function libraries: Invalid argument: Functions '__inference___backward_cudnn_lstm_with_fallback_855_1035' and '__inference___backward_cudnn_lstm_with_fallback_855_1035_specialized_for_StatefulPartitionedCall_1_gradients_sequential_lstm_StatefulPartitionedCall_grad_StatefulPartitionedCall_at___inference_train_step_3_1756' both implement 'lstm_a139ee92-2e42-4749-93b3-0b4af6926e59' but their signatures do not match.
8.711773838382214

@allenlavoie (Member)

@Andreas5739738, would you mind trying with the latest nightly? 2.2.0.dev20200311 should have 3a03164 which we hope fixes both the Grappler messages and performance.

@Andreas5739738

I can still see the same warning messages with tf-nightly:
https://colab.research.google.com/gist/Andreas5739738/c3ec84dba75adb6880124c840c0be12f/untitled0.ipynb

@allenlavoie (Member)

I don't see the Grappler messages in the colab. Which messages?

@Andreas5739738

If you run the code again and go to "Runtime" → "show logs", you will see them (they don't seem to be stored across runs).

@allenlavoie (Member)

Ah, I see. @rmlarsen any ideas? Those warnings do explicitly mention the implementation selector:

Mar 11, 2020, 10:39:03 AM WARNING 2020-03-11 17:39:03.239186: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:561] implementation_selector failed: Invalid argument: Invalid format of input node name: Expected: {forward_node_name}:{index}
Mar 11, 2020, 10:39:01 AM WARNING 2020-03-11 17:39:01.391898: W tensorflow/core/common_runtime/process_function_library_runtime.cc:697] Ignoring multi-device function optimization failure: Invalid argument: Input 0 of node sequential_lstm_statefulpartitionedcall_8_RetVal was passed float from sequential/lstm/StatefulPartitionedCall:9 incompatible with expected variant.
Mar 11, 2020, 10:39:01 AM WARNING 2020-03-11 17:39:01.383694: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:561] function_optimizer failed: Invalid argument: Input 0 of node sequential_lstm_statefulpartitionedcall_10_RetVal was passed float from sequential/lstm/StatefulPartitionedCall:11 incompatible with expected resource.
Mar 11, 2020, 10:39:01 AM WARNING 2020-03-11 17:39:01.350266: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:561] function_optimizer failed: Invalid argument: Input 0 of node sequential_lstm_statefulpartitionedcall_10_RetVal was passed float from sequential/lstm/StatefulPartitionedCall:11 incompatible with expected resource.

@saikumarchalla

@zhuzhi-fairy Could you please check with the latest TF version and let us know if the issue still persists. Thanks!

@zhuzhi-fairy (Author)

> @zhuzhi-fairy Could you please check with the latest TF version and let us know if the issue still persists. Thanks!

I have checked the code with TF 2.2.0, and I think the problem has been solved.
Thank you very much.

Andreas5739738 commented May 18, 2020

@saikumarchalla I can still reproduce the problem with tf_nightly-2.3.0.dev20200518-cp36-cp36m-manylinux2010_x86_64.whl and the colab linked above (run & view session logs):

2020-05-18 13:07:33.807689: W tensorflow/core/common_runtime/process_function_library_runtime.cc:773] Ignoring multi-device function optimization failure: Invalid argument: Input 0 of node sequential_lstm_partitionedcall_5_RetVal was passed float from sequential/lstm/PartitionedCall:6 incompatible with expected variant.

AlexFuster commented May 18, 2021

I am getting the same error in version 2.4.1.
The OP's code no longer reproduces the error; instead, it appears when you add an optimizer:

import time
import numpy as np
import tensorflow as tf

model0 = tf.keras.models.Sequential(
    [tf.keras.layers.LSTM(128, input_shape=(300, 40))]
)
model1 = tf.keras.models.Sequential(
    [tf.keras.layers.Dense(1, activation='sigmoid', input_shape=(128,))]
)
loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.Adam()

@tf.function
def train_step_0():
  with tf.GradientTape() as tape0, tf.GradientTape() as tape1:
    f = model0(tf.random.normal([256, 300, 40]))
    y_p0 = model1(f)
    y_p1 = model1(tf.random.normal((256, 128)))
    loss0 = loss_object(tf.zeros_like(y_p0), y_p0)
    loss1 = loss_object(tf.ones_like(y_p1), y_p1)
  grad0 = tape0.gradient(loss0, model0.trainable_variables)
  grad1 = tape1.gradient(loss1, model1.trainable_variables)
  optimizer.apply_gradients(zip(grad0, model0.trainable_variables))
  optimizer.apply_gradients(zip(grad1, model1.trainable_variables))

t0=time.time()
for i in range(100):
  train_step_0()
print(time.time()-t0)
import time
import numpy as np
import tensorflow as tf

model0 = tf.keras.models.Sequential(
    [tf.keras.layers.LSTM(128, input_shape=(300, 40))]
)
model1 = tf.keras.models.Sequential(
    [tf.keras.layers.Dense(1, activation='sigmoid', input_shape=(128,))]
)
loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.Adam()

@tf.function
def train_step_1():
  with tf.GradientTape(persistent=True) as tape:
    f = model0(tf.random.normal([256, 300, 40]))
    y_p0 = model1(f)
    y_p1 = model1(tf.random.normal((256, 128)))
    loss0 = loss_object(tf.zeros_like(y_p0), y_p0)
    loss1 = loss_object(tf.ones_like(y_p1), y_p1)
  grad0 = tape.gradient(loss0, model0.trainable_variables)
  grad1 = tape.gradient(loss1, model1.trainable_variables)
  optimizer.apply_gradients(zip(grad0, model0.trainable_variables))
  optimizer.apply_gradients(zip(grad1, model1.trainable_variables))

t0=time.time()
for i in range(100):
  train_step_1()
print(time.time()-t0)
import time
import numpy as np
import tensorflow as tf

model0 = tf.keras.models.Sequential(
    [tf.keras.layers.LSTM(128, input_shape=(300, 40))]
)
model1 = tf.keras.models.Sequential(
    [tf.keras.layers.Dense(1, activation='sigmoid', input_shape=(128,))]
)
loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.Adam()

@tf.function
def train_step_2():
  with tf.GradientTape() as tape0:
    f = model0(tf.random.normal([256, 300, 40]))
    y_p0 = model1(f)
    loss0 = loss_object(tf.zeros_like(y_p0), y_p0)
  with tf.GradientTape() as tape1:
    y_p1 = model1(tf.random.normal((256, 128)))
    loss1 = loss_object(tf.ones_like(y_p1), y_p1)
  grad0 = tape0.gradient(loss0, model0.trainable_variables)
  grad1 = tape1.gradient(loss1, model1.trainable_variables)
  optimizer.apply_gradients(zip(grad0, model0.trainable_variables))
  optimizer.apply_gradients(zip(grad1, model1.trainable_variables))

t0=time.time()
for i in range(100):
  train_step_2()
print(time.time()-t0)

Both train_step_0 and train_step_1 show the error, while train_step_2 doesn't. On my GPU, the first two approaches take around 11.5 s for 100 training steps, while the third takes 3.5 s.

Edit: I tried with TensorFlow 2.5 and the gap is even bigger.

This issue should be reopened @ymodak
