Low performance when using persistent mode GradientTape with LSTM layers #35928
Comments
There's some thought, or at least hope, that 3a03164 fixes the issue. Would you mind waiting until that's in a nightly (presumably tomorrow) and trying again? If that fixes it, then the performance issue was that this dtype mismatch meant Grappler wasn't choosing the CuDNN LSTM implementation, and presumably the alternative was much slower.
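For context, a minimal sketch (drawn from the TF 2.x Keras LSTM documentation, not from this issue's code) of the conditions under which the layer can be fused into the fast cuDNN kernel:

```python
import tensorflow as tf

# Per the TF 2.x Keras docs, the LSTM layer is fused into the cuDNN kernel
# only when all of these (default) settings hold; otherwise it falls back
# to a much slower generic graph implementation.
fast_lstm = tf.keras.layers.LSTM(
    128,
    activation='tanh',               # must be tanh
    recurrent_activation='sigmoid',  # must be sigmoid
    recurrent_dropout=0.0,           # must be 0
    unroll=False,                    # must be False
    use_bias=True,                   # must be True
)

# Any deviation, e.g. a non-default activation, disables the cuDNN path:
slow_lstm = tf.keras.layers.LSTM(128, activation='relu')
```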
I can trigger the same warning messages and performance degradation by removing the @tf.function decorator from the GradientTape and only decorating its body. Example here with TF 2.0.0:
Output:
After adding the @tf.function decorator to the outside of the GradientTape, I no longer get the warnings and it is faster:
Output:
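A minimal sketch of the two decorator placements being compared; `model`, `loss_fn`, `optimizer`, `x`, and `y` are hypothetical stand-ins, not names from the example above:

```python
import tensorflow as tf

# Hypothetical toy setup, only to make the two placements runnable.
model = tf.keras.Sequential([tf.keras.layers.LSTM(8, input_shape=(4, 2)),
                             tf.keras.layers.Dense(1)])
loss_fn = tf.keras.losses.MeanSquaredError()
optimizer = tf.keras.optimizers.Adam()
x, y = tf.random.normal([16, 4, 2]), tf.random.normal([16, 1])

# Slow placement: only the tape's body is compiled, so the tape itself and
# the gradient computation run eagerly.
@tf.function
def compute_loss(x, y):
    return loss_fn(y, model(x))

def train_step_slow(x, y):
    with tf.GradientTape() as tape:
        loss = compute_loss(x, y)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))

# Fast placement: the whole step, tape included, is traced into one graph.
@tf.function
def train_step_fast(x, y):
    with tf.GradientTape() as tape:
        loss = loss_fn(y, model(x))
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
```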
@Andreas5739738, would you mind trying with the latest nightly? 2.2.0.dev20200311 should have 3a03164, which we hope fixes both the Grappler messages and performance.
I can still see the same warning messages with tf-nightly:
I don't see the Grappler messages in the colab. Which messages?
If you run the code again and go to "Runtime" → "show logs", you will see them (they don't seem to be stored across runs). |
Ah, I see. @rmlarsen, any ideas? Those warnings do explicitly mention the implementation selector:
@zhuzhi-fairy Could you please check with the latest TF version and let us know if the issue still persists. Thanks!
I have checked the code with TF 2.2.0, and I think the problem has been solved.
@saikumarchalla I can still reproduce the problem with tf_nightly-2.3.0.dev20200518-cp36-cp36m-manylinux2010_x86_64.whl and the colab linked above (run & view session logs):
I am having the same error in version 2.4.1:

```python
import time
import numpy as np
import tensorflow as tf

model0 = tf.keras.models.Sequential(
    [tf.keras.layers.LSTM(128, input_shape=(300, 40))]
)
model1 = tf.keras.models.Sequential(
    [tf.keras.layers.Dense(1, activation='sigmoid', input_shape=(128,))]
)
loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.Adam()

# Variant 0: two tapes opened in a single `with` statement.
@tf.function
def train_step_0():
    with tf.GradientTape() as tape0, tf.GradientTape() as tape1:
        f = model0(tf.random.normal([256, 300, 40]))
        y_p0 = model1(f)
        y_p1 = model1(tf.random.normal((256, 128)))
        loss0 = loss_object(tf.zeros_like(y_p0), y_p0)
        loss1 = loss_object(tf.ones_like(y_p1), y_p1)
    grad0 = tape0.gradient(loss0, model0.trainable_variables)
    grad1 = tape1.gradient(loss1, model1.trainable_variables)
    optimizer.apply_gradients(zip(grad0, model0.trainable_variables))
    optimizer.apply_gradients(zip(grad1, model1.trainable_variables))

t0 = time.time()
for i in range(100):
    train_step_0()
print(time.time() - t0)
```

```python
import time
import numpy as np
import tensorflow as tf

model0 = tf.keras.models.Sequential(
    [tf.keras.layers.LSTM(128, input_shape=(300, 40))]
)
model1 = tf.keras.models.Sequential(
    [tf.keras.layers.Dense(1, activation='sigmoid', input_shape=(128,))]
)
loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.Adam()

# Variant 1: one persistent tape reused for both gradient calls.
@tf.function
def train_step_1():
    with tf.GradientTape(persistent=True) as tape:
        f = model0(tf.random.normal([256, 300, 40]))
        y_p0 = model1(f)
        y_p1 = model1(tf.random.normal((256, 128)))
        loss0 = loss_object(tf.zeros_like(y_p0), y_p0)
        loss1 = loss_object(tf.ones_like(y_p1), y_p1)
    grad0 = tape.gradient(loss0, model0.trainable_variables)
    grad1 = tape.gradient(loss1, model1.trainable_variables)
    optimizer.apply_gradients(zip(grad0, model0.trainable_variables))
    optimizer.apply_gradients(zip(grad1, model1.trainable_variables))

t0 = time.time()
for i in range(100):
    train_step_1()
print(time.time() - t0)
```

```python
import time
import numpy as np
import tensorflow as tf

model0 = tf.keras.models.Sequential(
    [tf.keras.layers.LSTM(128, input_shape=(300, 40))]
)
model1 = tf.keras.models.Sequential(
    [tf.keras.layers.Dense(1, activation='sigmoid', input_shape=(128,))]
)
loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.Adam()

# Variant 2: two separate, sequential `with` blocks.
@tf.function
def train_step_2():
    with tf.GradientTape() as tape0:
        f = model0(tf.random.normal([256, 300, 40]))
        y_p0 = model1(f)
        loss0 = loss_object(tf.zeros_like(y_p0), y_p0)
    with tf.GradientTape() as tape1:
        y_p1 = model1(tf.random.normal((256, 128)))
        loss1 = loss_object(tf.ones_like(y_p1), y_p1)
    grad0 = tape0.gradient(loss0, model0.trainable_variables)
    grad1 = tape1.gradient(loss1, model1.trainable_variables)
    optimizer.apply_gradients(zip(grad0, model0.trainable_variables))
    optimizer.apply_gradients(zip(grad1, model1.trainable_variables))

t0 = time.time()
for i in range(100):
    train_step_2()
print(time.time() - t0)
```

Both train_step_0 and train_step_1 show the error, while train_step_2 doesn't. On my GPU, the first two approaches take around 11.5 s for 100 training steps, while the third one takes 3.5 s.

Edit: I tried with TensorFlow 2.5 and the gap is even bigger. This issue should be reopened @ymodak
System information
Describe the current behavior
Performance is very low when using a persistent-mode tf.GradientTape, or when creating multiple GradientTape objects in one `with` block. This phenomenon only happens when the model includes an LSTM layer.
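Schematically, the slow and fast tape patterns look like this (a condensed sketch with toy tensors, not the scripts from the report):

```python
import tensorflow as tf

v = tf.Variable(2.0)

# Slow pattern 1: a persistent tape reused for several gradient calls.
with tf.GradientTape(persistent=True) as tape:
    y0 = v * v
    y1 = v + v
g0 = tape.gradient(y0, v)
g1 = tape.gradient(y1, v)
del tape  # release the persistent tape's resources

# Slow pattern 2: two tapes opened in a single `with` statement.
with tf.GradientTape() as tape0, tf.GradientTape() as tape1:
    y0 = v * v
    y1 = v + v
g0 = tape0.gradient(y0, v)
g1 = tape1.gradient(y1, v)

# Fast pattern: separate, sequential `with` blocks (train_step_2 below).
with tf.GradientTape() as tape0:
    y0 = v * v
g0 = tape0.gradient(y0, v)
with tf.GradientTape() as tape1:
    y1 = v + v
g1 = tape1.gradient(y1, v)
```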
Code to reproduce the issue
Here is sample code to reproduce the problem (the scripts train_step_0, train_step_1, and train_step_2).
With train_step_0, GPU utilization was only about 30% and the run took 37 seconds. The following message was shown on screen:
With train_step_1, the performance was similar to the previous one.
With train_step_2, GPU utilization was almost 100% and the time was only 4 seconds.
In my opinion, train_step_0, train_step_1, and train_step_2 should have similar performance. I am wondering why the GPU utilization and processing times are so different.
The strangest thing is that this phenomenon only happens when the model includes an LSTM layer: if you replace the LSTM layer with a Conv or Dense layer, the processing time is the same for all three.
Here is a Colaboratory notebook to reproduce this problem:
https://colab.research.google.com/drive/1sluVFuW1yYtH0Ye4reoOEmLUHYHDSGn7
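One way to probe the cuDNN hypothesis from the comments above (a sketch of a hypothetical experiment, not something from this issue): `recurrent_dropout > 0` is one of the documented conditions that disables the fused cuDNN LSTM kernel, so forcing the generic kernel even in the fast single-tape pattern should erase its advantage if kernel selection is indeed the cause.

```python
import time
import tensorflow as tf

# Hypothetical probe: recurrent_dropout != 0 disables the fused cuDNN kernel
# (per the TF 2.x Keras LSTM docs), so this model always runs the generic
# implementation regardless of the tape pattern.
model = tf.keras.models.Sequential(
    [tf.keras.layers.LSTM(128, recurrent_dropout=0.001, input_shape=(300, 40))]
)

@tf.function
def step():
    # Same shape of work as the fast train_step_2 path: one tape, one gradient.
    with tf.GradientTape() as tape:
        out = model(tf.random.normal([256, 300, 40]))
        loss = tf.reduce_mean(out)
    return tape.gradient(loss, model.trainable_variables)

step()  # trace once so the timing excludes graph construction
t0 = time.time()
for _ in range(100):
    step()
print(time.time() - t0)  # expected to land near the slow variants' times
```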