Low performance when using persistent mode GradientTape with LSTM/GRU layers #51818

Open
AlexFuster opened this issue Sep 3, 2021 · 4 comments
Labels: 2.6.0, comp:ops (OPs related issues), stat:awaiting tensorflower (Status - Awaiting response from tensorflower), type:performance (Performance Issue)

Comments

@AlexFuster

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): Yes
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Ubuntu 20.04
  • Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device:
  • TensorFlow installed from (source or binary): binary
  • TensorFlow version (use command below): 2.6.0
  • Python version: 3.8.10
  • CUDA/cuDNN version: 11.0/8.2.2
  • GPU model and memory: NVIDIA RTX Titan 24GB

Describe the current behavior
Performance is very low in graph mode when using a persistent tf.GradientTape, or when creating multiple GradientTape objects in one with block.
This only happens when the model includes an LSTM or GRU layer.

Standalone code to reproduce the issue

import time
import tensorflow as tf

# LSTM feature extractor followed by a small binary classifier head.
model0 = tf.keras.models.Sequential(
    tf.keras.layers.LSTM(128, input_shape=(300, 40))
)
model1 = tf.keras.models.Sequential(
    tf.keras.layers.Dense(1, activation='sigmoid', input_shape=(128,))
)
loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.Adam()

# Case 0: two non-persistent tapes opened in the same with block.
@tf.function
def train_step_0():
  with tf.GradientTape() as tape0, tf.GradientTape() as tape1:
    f = model0(tf.random.normal([256, 300, 40]))
    y_p0 = model1(f)
    y_p1 = model1(tf.random.normal((256, 128)))
    loss0 = loss_object(tf.zeros_like(y_p0), y_p0)
    loss1 = loss_object(tf.ones_like(y_p1), y_p1)
  grad0 = tape0.gradient(loss0, model0.trainable_variables)
  grad1 = tape1.gradient(loss1, model1.trainable_variables)
  optimizer.apply_gradients(zip(grad0, model0.trainable_variables))
  optimizer.apply_gradients(zip(grad1, model1.trainable_variables))

# Time 100 training steps.
t0 = time.time()
for i in range(100):
  train_step_0()
print(time.time() - t0)

output

2021-09-03 12:39:22.262342: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:801] function_optimizer failed: Invalid argument: Input 0 of node zeros_like_1 was passed float from sequential/lstm/PartitionedCall:6 incompatible with expected variant.
2021-09-03 12:39:22.287055: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:801] function_optimizer failed: Invalid argument: Input 0 of node zeros_like_1 was passed float from sequential/lstm/PartitionedCall:6 incompatible with expected variant.
2021-09-03 12:39:22.300477: W tensorflow/core/common_runtime/process_function_library_runtime.cc:841] Ignoring multi-device function optimization failure: Invalid argument: Input 0 of node zeros_like_1 was passed float from sequential/lstm/PartitionedCall:6 incompatible with expected variant.
17.58007049560547
import time
import tensorflow as tf

model0 = tf.keras.models.Sequential(
    tf.keras.layers.LSTM(128, input_shape=(300, 40))
)
model1 = tf.keras.models.Sequential(
    tf.keras.layers.Dense(1, activation='sigmoid', input_shape=(128,))
)
loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.Adam()

# Case 1: a single persistent tape reused for both gradient computations.
@tf.function
def train_step_1():
  with tf.GradientTape(persistent=True) as tape:
    f = model0(tf.random.normal([256, 300, 40]))
    y_p0 = model1(f)
    y_p1 = model1(tf.random.normal((256, 128)))
    loss0 = loss_object(tf.zeros_like(y_p0), y_p0)
    loss1 = loss_object(tf.ones_like(y_p1), y_p1)
  grad0 = tape.gradient(loss0, model0.trainable_variables)
  grad1 = tape.gradient(loss1, model1.trainable_variables)
  optimizer.apply_gradients(zip(grad0, model0.trainable_variables))
  optimizer.apply_gradients(zip(grad1, model1.trainable_variables))

t0 = time.time()
for i in range(100):
  train_step_1()
print(time.time() - t0)

output

2021-09-03 12:43:44.947280: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:801] function_optimizer failed: Invalid argument: Input 0 of node zeros_like_1 was passed float from sequential/lstm/PartitionedCall:6 incompatible with expected variant.
2021-09-03 12:43:44.972180: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:801] function_optimizer failed: Invalid argument: Input 0 of node zeros_like_1 was passed float from sequential/lstm/PartitionedCall:6 incompatible with expected variant.
2021-09-03 12:43:44.985573: W tensorflow/core/common_runtime/process_function_library_runtime.cc:841] Ignoring multi-device function optimization failure: Invalid argument: Input 0 of node zeros_like_1 was passed float from sequential/lstm/PartitionedCall:6 incompatible with expected variant.
16.632988929748535
import time
import tensorflow as tf

model0 = tf.keras.models.Sequential(
    tf.keras.layers.LSTM(128, input_shape=(300, 40))
)
model1 = tf.keras.models.Sequential(
    tf.keras.layers.Dense(1, activation='sigmoid', input_shape=(128,))
)
loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.Adam()

# Case 2: two separate non-persistent tapes, each in its own with block.
@tf.function
def train_step_2():
  with tf.GradientTape() as tape0:
    f = model0(tf.random.normal([256, 300, 40]))
    y_p0 = model1(f)
    loss0 = loss_object(tf.zeros_like(y_p0), y_p0)
  with tf.GradientTape() as tape1:
    y_p1 = model1(tf.random.normal((256, 128)))
    loss1 = loss_object(tf.ones_like(y_p1), y_p1)
  grad0 = tape0.gradient(loss0, model0.trainable_variables)
  grad1 = tape1.gradient(loss1, model1.trainable_variables)
  optimizer.apply_gradients(zip(grad0, model0.trainable_variables))
  optimizer.apply_gradients(zip(grad1, model1.trainable_variables))

t0 = time.time()
for i in range(100):
  train_step_2()
print(time.time() - t0)

output

4.321804523468018

Other info
Both train_step_0 and train_step_1 show the error, while train_step_2 doesn't. On my GPU, the first two approaches take around 17 s for 100 training steps, while the third one takes 4.3 s.
Furthermore, this performance drop only reproduces when using GRU/LSTM layers in graph mode. That is, if we remove the tf.function decorator from the train_step functions, or if we replace the LSTM with a Dense layer (see the sketch below), all three examples take the same time and none of them prints any error.
Additionally, the problem occurs both on CPU and on GPU.
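
For reference, a minimal sketch of the Dense-only control experiment mentioned above. The Flatten + Dense stack and layer sizes are illustrative choices, not part of the original report; the point is only that the two-tape structure of train_step_0 is kept while the recurrent layer is removed, in which case the warning and the slowdown should not appear.

import time
import tensorflow as tf

# Same setup as train_step_0, but the LSTM is replaced by a Dense layer,
# so the input sequence is flattened instead of processed recurrently.
model0 = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(300, 40)),
    tf.keras.layers.Dense(128),
])
model1 = tf.keras.models.Sequential(
    tf.keras.layers.Dense(1, activation='sigmoid', input_shape=(128,))
)
loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.Adam()

@tf.function
def train_step_dense():
  # Two tapes in one with block, exactly as in train_step_0.
  with tf.GradientTape() as tape0, tf.GradientTape() as tape1:
    f = model0(tf.random.normal([256, 300, 40]))
    y_p0 = model1(f)
    y_p1 = model1(tf.random.normal((256, 128)))
    loss0 = loss_object(tf.zeros_like(y_p0), y_p0)
    loss1 = loss_object(tf.ones_like(y_p1), y_p1)
  grad0 = tape0.gradient(loss0, model0.trainable_variables)
  grad1 = tape1.gradient(loss1, model1.trainable_variables)
  optimizer.apply_gradients(zip(grad0, model0.trainable_variables))
  optimizer.apply_gradients(zip(grad1, model1.trainable_variables))

t0 = time.time()
for i in range(100):
  train_step_dense()
print(time.time() - t0)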

By the way, this issue is an updated version of #35928, which addressed a very similar problem.

@jvishnuvardhan
Contributor

@AlexFuster Thanks for creating this issue. This looks more related to keras-team/keras, so I moved this issue to the keras-team/keras repo for resolution. Thanks!

@jvishnuvardhan added the stat:awaiting tensorflower (Status - Awaiting response from tensorflower) label on Sep 8, 2021
@AlexFuster
Author

Well, the Keras team doesn't seem to agree with that.

@jvishnuvardhan
Contributor

@AlexFuster This looks more related to TF core, so this repo is the right place for this issue. Thanks!

@rainwoodman
Member

I noticed an interesting commit flipping recurrent_v2._use_new_code() to False:

73b7097

This could explain why the performance degraded again.

I checked that if I monkey-patch a revert of that commit, the speed improves drastically:

from keras.layers import recurrent_v2
recurrent_v2._use_new_code = lambda: True

The change is by @yhliang2018 -- any background on why we reverted to the old / slower code path?
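
A minimal sketch of how this workaround could be tried, assuming TF 2.6 with the standalone keras package. The recurrent_v2 module and _use_new_code hook come from the comment above; they are private internals, so this is an experiment rather than a supported fix:

import tensorflow as tf
from keras.layers import recurrent_v2

# Force Keras back onto the new (faster) RNN code path.
# _use_new_code is a private hook and may change between versions.
recurrent_v2._use_new_code = lambda: True

# Build the LSTM/GRU models *after* applying the patch, so the layers
# pick up the patched code path when they are constructed.
model0 = tf.keras.models.Sequential(
    tf.keras.layers.LSTM(128, input_shape=(300, 40))
)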
