

Upgrading from TF 2.1 to 2.2 gives 12% slowdown and 23% memory increase #39434

Open
jarednielsen opened this issue May 12, 2020 · 24 comments
Labels: comp:gpu (GPU related issues), comp:keras (Keras related issues), stat:awaiting tensorflower (Status - Awaiting response from tensorflower), TF 2.7 (Issues related to TF 2.7.0), type:performance (Performance Issue)

Comments

@jarednielsen
Contributor

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): Yes
  • OS Platform and Distribution: Ubuntu 18.04
  • TensorFlow version (use command below): 2.1.0, 2.2.0, tf-nightly
  • Python version: 3.6.9
  • CUDA/cuDNN version: 10.1
  • GPU model and memory: 8xTesla V100, 32GB each

Describe the current behavior

I'm running language modeling experiments with ALBERT, and GPU memory is at a premium due to the large batch sizes necessary. Upgrading from TF 2.1.0 to 2.2.0, I experienced OOM errors, so I ran a few benchmarks:

|                | TF nightly (May 8) | TF 2.2 | TF 2.1 |
|----------------|--------------------|--------|--------|
| Iterations/sec | 1.63               | 1.64   | 1.86   |
| GPU memory     | 26 GB              | 25 GB  | 21 GB  |

That's a 12% slowdown combined with a 23% memory increase. I cannot upgrade until performance is matched. Are there any new experimental options or changes I should be aware of that caused this massive performance hit? I'm using tf.function, XLA, and AMP.

It seems that fewer and fewer ops are converted to mixed-precision as we progress from TF 2.1->2.2->nightly. Is that related, and how can I restore the original behavior?

Poring over the release notes, the only thing that sticks out is:
tf.constant always creates CPU tensors irrespective of the current device context.
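
If I read that note right, it means something like the following (a minimal sketch, not from my training code): a constant created inside a GPU device scope now lives on the CPU, and an explicit op such as tf.identity is needed to get a copy onto the GPU.

import tensorflow as tf

with tf.device("/GPU:0"):
    c = tf.constant([1.0, 2.0, 3.0])  # per the 2.2 release note, this lands on the CPU despite the device scope
    g = tf.identity(c)                # the identity op runs on the GPU, so g lives there

print(c.device)  # expected: .../device:CPU:0
print(g.device)  # expected: .../device:GPU:0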

@amahendrakar
Contributor

@jarednielsen,
In order to expedite the troubleshooting process, could you please provide the complete code to reproduce the issue reported here? Thanks!

@amahendrakar amahendrakar added comp:gpu GPU related issues stat:awaiting response Status - Awaiting response from author labels May 12, 2020
@jarednielsen
Contributor Author
jarednielsen commented May 14, 2020

Sure, replication on this one is a bit complicated since it requires comparing across TF versions and with different config options. Run

pip install tensorflow==2.2.0
pip install transformers==2.9.1

Then save the following self-contained training script as performance.py:

# performance.py
"""
Demonstrates a performance regression from TF 2.1 to 2.2 when using model.call(training=True).
Memory increases from 11.7GB to 17.6GB and it/s decreases from 7.42 to 6.24.
Use the parameters --batch_size=16 --accum=1, and try with & without --training.
"""

import argparse
import time

import numpy as np
import tensorflow as tf
from tqdm import tqdm
from transformers import AlbertConfig, TFAlbertForPreTraining


def train_batch(input_dict, training: bool):
    with tf.GradientTape() as tape:
        mlm_logits, sop_logits = model(input_dict, training=training)
        loss = tf.reduce_mean(mlm_logits) + tf.reduce_mean(sop_logits)
        scaled_loss = opt.get_scaled_loss(loss)
    # Differentiate the scaled loss, then unscale the gradients before applying them.
    scaled_grads = tape.gradient(scaled_loss, model.trainable_variables)
    opt.apply_gradients(zip(opt.get_unscaled_gradients(scaled_grads), model.trainable_variables))


def benchmark_lm(num_epochs: int, training: bool):
    start_time = time.perf_counter()
    for i, batch in tqdm(enumerate(dataset.take(num_epochs))):
        train_batch(batch, training)
    tf.print(f"Execution time: {time.perf_counter() - start_time}")


def get_synthetic_mlm_unlabeled_dataset(batch_size: int) -> tf.data.Dataset:
    """ Returns a dataset that includes batching, but not gradient accumulation. """

    def gen(batch_size):
        shape = [batch_size, 512]
        input_ids = tf.constant(np.random.randint(10, size=shape), dtype=tf.int32)
        attention_mask = tf.constant(np.random.randint(2, size=shape), dtype=tf.int32)
        token_type_ids = tf.constant(np.random.randint(2, size=shape), dtype=tf.int32)

        input_dict = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids,
        }

        yield input_dict

    dataset = tf.data.Dataset.from_generator(
        gen,
        output_types={
            "input_ids": tf.int32,
            "attention_mask": tf.int32,
            "token_type_ids": tf.int32,
        },
        args=(batch_size,),
    )
    dataset = dataset.repeat()
    return dataset


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--eager", action="store_true")
    parser.add_argument("--batch_size", type=int)
    parser.add_argument("--training", action="store_true")
    args = parser.parse_args()

    eager = args.eager
    batch_size = args.batch_size
    training = args.training

    gpus = tf.config.list_physical_devices("GPU")
    tf.config.experimental.set_visible_devices(gpus[0], "GPU")
    tf.config.experimental.set_memory_growth(gpus[0], True)

    tf.config.optimizer.set_jit(True)  # XLA
    tf.config.experimental_run_functions_eagerly(eager)  # run tf.functions eagerly when --eager is passed
    tf.config.optimizer.set_experimental_options({"auto_mixed_precision": True})  # AMP

    train_batch = tf.function(train_batch)

    config = AlbertConfig.from_pretrained("albert-base-v2")
    model = TFAlbertForPreTraining(config)
    opt = tf.keras.optimizers.Adam(0.01)
    opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(opt, loss_scale="dynamic")

    dataset = get_synthetic_mlm_unlabeled_dataset(batch_size)

    # Trace and compile the tf.function once before timing
    train_batch(next(iter(dataset)), training)
    benchmark_lm(num_epochs=50, training=training)

Run the following:

python performance.py --batch_size=16
python performance.py --batch_size=16 --training

while monitoring GPU utilization with watch -n 1 nvidia-smi.
Then repeat the above, but with TF 2.1 installed.
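
To get numbers you can log and diff across runs (instead of eyeballing the refreshing table), the stock nvidia-smi query interface also works:

nvidia-smi --query-gpu=timestamp,memory.used --format=csv -l 1 > gpu_memory.log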

The output on my Tesla V100 (32GB RAM) is

| Simple benchmark | 2.1 it/s | 2.2 it/s | 2.1 memory | 2.2 memory |
|------------------|----------|----------|------------|------------|
| training=False   | 7.13     | 7.09     | 11.7 GB    | 11.7 GB    |
| training=True    | 7.12     | 6.43     | 11.7 GB    | 17.6 GB    |

Note that TF 2.1 performance is consistent between training=False and training=True, while TF 2.2 drops in speed by 10% and increases memory usage by 50%. The language model being benchmarked uses a dropout layer (with dropout prob = 0, so it's a no-op) and layer normalization, but not batch normalization.
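
As a quick sanity check on that claim, independent of ALBERT: a rate-0 dropout layer should be the identity for either value of training. A minimal sketch:

import numpy as np
import tensorflow as tf

x = tf.random.normal([4, 8])
dropout = tf.keras.layers.Dropout(rate=0.0)

# With rate=0 nothing is masked and the 1/(1 - rate) rescaling is 1, so both calls return x.
np.testing.assert_allclose(dropout(x, training=True).numpy(), x.numpy())
np.testing.assert_allclose(dropout(x, training=False).numpy(), x.numpy())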

@amahendrakar Are you able to reproduce the issue with the above code?

@amahendrakar amahendrakar added TF 2.1 for tracking issues in 2.1 release TF 2.2 Issues related to TF 2.2 and removed stat:awaiting response Status - Awaiting response from author labels May 14, 2020
@jarednielsen
Contributor Author

I've benchmarked this effect with various model sizes, by changing the number of hidden layers. Happy to share more code if necessary. These are the results:

| training=True | 2.1 memory | 2.2 memory |
|---------------|------------|------------|
| 2 layers      | 3.2 GB     | 3.2 GB     |
| 4 layers      | 5.3 GB     | 5.3 GB     |
| 6 layers      | 5.3 GB     | 9.4 GB     |
| 8 layers      | 5.3 GB     | 9.4 GB     |
| 10 layers     | 9.4 GB     | 9.4 GB     |
| 12 layers     | 9.4 GB     | 17.6 GB    |

When training=False, the 2.2 memory results are identical to 2.1. Yet dropout probabilities are set to 0, so the value of training should be irrelevant.

It appears that memory usage increases in discrete jumps, and that training=True increases the memory required by TF 2.2 (sometimes doubling it!).
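
For reference, the sweep above just varies the layer count on the ALBERT config and reruns the same script; roughly like this (a sketch, not my exact harness; overriding num_hidden_layers via from_pretrained is one way to do it):

from transformers import AlbertConfig, TFAlbertForPreTraining

for num_layers in [2, 4, 6, 8, 10, 12]:
    config = AlbertConfig.from_pretrained("albert-base-v2", num_hidden_layers=num_layers)
    model = TFAlbertForPreTraining(config)
    # ...then run the same train_batch / benchmark_lm loop from performance.py above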

@gowthamkpr Are you able to reproduce the issue?

@zongweiz
Contributor

On the memory usage side, I am surprised that in TF 2.1 training=True consumes the same amount of memory as training=False. In the Keras dropout layer, training=False means that no dropout is done and thus no dropout mask needs to be preserved for backprop, so I would expect training=True to always consume more memory due to dropout (unless XLA optimizes away the dropout mask).

Could you please run the model with only AMP (no XLA) on training=True/False and see how this changes memory usage? Thanks.
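
Concretely, that would just mean flipping the XLA line in your script while keeping the AMP one (same calls you already use):

tf.config.optimizer.set_jit(False)  # XLA off
tf.config.optimizer.set_experimental_options({"auto_mixed_precision": True})  # AMP still on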

@zongweiz
Contributor

It would also be helpful if you could try out the Keras mixed precision API:
https://www.tensorflow.org/guide/keras/mixed_precision
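
i.e. roughly this pattern (the experimental Keras API as of TF 2.2/2.3; a sketch, not tied to your script):

import tensorflow as tf
from tensorflow.keras.mixed_precision import experimental as mixed_precision

# Set the global policy so layers compute in float16 and keep float32 variables.
mixed_precision.set_policy(mixed_precision.Policy("mixed_float16"))

opt = tf.keras.optimizers.Adam(0.01)
opt = mixed_precision.LossScaleOptimizer(opt, loss_scale="dynamic")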

@jarednielsen
Contributor Author

The model is using dropout_prob=0, so there should be no difference between training=True and training=False.

I am using the Keras mixed-precision API, aside from tf.keras.mixed_precision.experimental.set_policy('mixed_float16'), which fails due to dtype mismatches:

tf.config.optimizer.set_experimental_options({"auto_mixed_precision": True})  # AMP
opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(opt, loss_scale="dynamic")
scaled_loss = opt.get_scaled_loss(loss)
scaled_grads = tape.gradient(scaled_loss, model.trainable_variables)
opt.apply_gradients(zip(opt.get_unscaled_gradients(scaled_grads), model.trainable_variables))

However, mixed precision is not the issue: the same memory bug occurs with full precision, without XLA, and with both XLA and AMP disabled.

Are you able to run the code above? I tried to make it a minimal reproducible example.

@tensorflowbutler tensorflowbutler removed the stat:awaiting tensorflower Status - Awaiting response from tensorflower label May 20, 2020
@sanjoy
Contributor
sanjoy commented May 21, 2020

which fails due to dtype mismatches:

CC @reedwm

@jimmy6
jimmy6 commented May 31, 2020

I also hit a memory issue once I upgraded from TF 2.0 (GPU) to 2.2. I already reduced the data size by half, but I still get the error during model.fit, and the Nvidia GPU simply disappears on Windows.


~\Anaconda3\envs\Transformer_1D-CNN_Feature_Extraction_tf2.2\lib\site-packages\tensorflow\python\profiler\profiler_v2.py in stop(save)
    107     if save:
    108       try:
--> 109         _profiler.export_to_tb()
    110       except Exception:
    111         _profiler = None

MemoryError: bad allocation

Notebook is here
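
A possible way to sidestep the crash, assuming the profiler is being started by the TensorBoard callback in the notebook: profile_batch=0 turns profiling off entirely, so the failing export_to_tb step is never reached. A sketch under that assumption:

import tensorflow as tf

tensorboard_cb = tf.keras.callbacks.TensorBoard(log_dir="./logs", profile_batch=0)  # 0 = no profiling
# model.fit(..., callbacks=[tensorboard_cb])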

@jarednielsen
Contributor Author

@sanjoy The mixed-precision dtype mismatch is a problem, but it's orthogonal to the performance regression. The performance regression occurs in full-precision as well. Have you replicated the issue?

@reedwm
Member
reedwm commented Jun 2, 2020

It seems that fewer and fewer ops are converted to mixed-precision as we progress from TF 2.1->2.2->nightly. Is that related, and how can I restore the original behavior?

I can reproduce this, and I also observe that AMP is converting fewer nodes to float16. Assigning to myself.

@reedwm reedwm assigned reedwm and unassigned sanjoy Jun 2, 2020
@sanjoy
Contributor
sanjoy commented Jun 2, 2020

@reedwm can you also take a look at the regression @jarednielsen mentioned above that seems unrelated to mixed precision?

@jarednielsen
Contributor Author

@reedwm Any insight on the problem or how to sidestep it?

@dspyrhsu

Actually, I believe this might be an issue with TF 2.2 in general: a (very simple) transfer learning model of mine works perfectly and fast with or without GPU on TF 2.1 and essentially fails on TF 2.2, both with and without GPU. It still trains the model, but the first epoch takes 71s with TF 2.2 where it took 7s with TF 2.1. Finally, inference on 20 previously unseen samples takes one second on TF 2.1, while on TF 2.2 that step never finishes.

There might be something I missed, but I find it strange that this just "happens" in a notebook where I did not change a line of code ...

Sorry for being verbose while not posting the code, but I believe this is not easily reproduced without the data. However, if there is more I could do to provide additional info, please let me know.

@mihaimaruseac
Collaborator

Can you compare with 2.3 now that it's near release?

@dspyrhsu
dspyrhsu commented Jul 24, 2020

Hi,

I am sorry to say that I am not sure how to install 2.3. If you mean installing tf-nightly via pip, I did that just now (I am, however, using conda for everything else) and reran the notebook. I definitely see an improvement here: the first epoch of training now has only a small overhead compared to the other epochs, and those are very fast. However, I again had to interrupt the prediction step (which takes only one second with TF 2.1, as mentioned above), since it wouldn't finish.

As I write this, I notice that the install command only installed the CPU version of TensorFlow, so I still have to check the GPU version regarding the problem with the first epoch.

Anyway, I did install tf-nightly-gpu just now (which gives TF 2.4, if I see this correctly), and it complained about the installed scipy version being too new (the warning message said TensorFlow 2.2.0 needs scipy 1.4.1, so I wonder whether this might be the problem?). So I downgraded scipy, and I can report that everything worked just fine with this setup (still CPU only, since I need to figure out a way to avoid installing the whole CUDA stack manually - conda normally does that for me).

Once I find the time to approach this more systematically, I will get back. And, before I forget, thanks for getting back to me.

@dspyrhsu
dspyrhsu commented Jul 25, 2020

So, I will describe my experiments briefly, but the gist is that with tensorflow==2.3.0rc2 (as well as with the nightly of July 25th, 2020) inference works as it did/does with TF 2.1, and the problem is NOT due to the scipy version.

Detailed Experiments

Now on to the detailed experiments, which all start with a fresh miniconda install and the creation of virtual environments for the setups. In all cases, I first do
conda install jupyterlab pandas numpy scipy scikit-learn matplotlib seaborn

First setup (TF-2.2):

Here I install TF as it comes with conda:

conda install tensorflow-gpu
When I train and test my model, I see both the slow start of training (only when using the GPU) and the failure to carry out the prediction step on the test data (both with the GPU and with CPU only).

Second setup (TF-2.1):

Here I install TF in a downgraded conda version:

conda install tensorflow-gpu==2.1.0
When I train and test my model, I experience a moderately slow start of training (only when using GPU), but everything else works just fine.

Third setup (TF-2.3rc2 / TF-nightly):

Here I first install TF as usual (to obtain cudnn, cupti, and cudatoolkit), only to remove it and replace it with the pip version. Also, I downgrade scipy from 1.5.0 to 1.4.1 in order not to get a warning from the pip install:

conda install tensorflow-gpu
conda uninstall --force tensorflow tensorflow-base tensorflow-gpu scipy
pip install scipy==1.4.1 tensorflow-gpu==2.3.0rc2

When I train and test my model, I experience a very slow start of training (only when using GPU), but everything else works just fine. The same is true if I replace tensorflow-gpu==2.3.0rc2 with tf-nightly-gpu.

Additional Information on the slow start of training

So, the slow start of training does NOT happen if I use the CPU only, but it does happen when using the GPU (it seems to have been reported before; the error message is the one reported here). It should be noted that with TF 2.1 the error message just reads

Not found: ./bin/ptxas not found Relying on driver to perform ptx compilation. This message will be only logged once.

and then TF more or less just gets on with its business (the first epoch takes ~7s, so there is a moderate holdup, but this is normal and also happens, in further reduced form, when using the CPU only), while for all later versions there is a long holdup before training starts (the first epoch takes between 35s and 71s).

Warning: You might want to refrain from trying the next part ...

Now, it seems one can get around this problem by reading all of this Stackoverflow discussion and, immediately after installing TF via conda, combining two of the answers:
conda install -c conda-forge cudatoolkit-dev
This will update the standard cudatoolkit and provide both nvcc as well as ptxas.
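
To confirm the tools are actually on the PATH of the environment after that install:

which nvcc && nvcc --version
which ptxas && ptxas --version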

However, there are at least two caveats (and one Ooops - see below)

  • Please note that apparently this package can be installed in only one virtual environment. In order to install it in the environment with TF-2.3rc2, I first had to remove the environment with the nightly where I had originally installed it (maybe uninstalling just the toolkit would have been enough, but I did not want to keep the nightly anyway).

  • Also, the virtual environment will have its own cudatoolkit in its pkgs directory. However, a lot of the shared library files in its lib64 directory have length zero, so, unsurprisingly, the GPU cannot be used. This can be remedied by copying over everything from the global cudatoolkit lib64 directory, but that still leaves quite a few zero-length files, which might matter somewhere. For running the training, however, everything seems to be ok.

And now to the really bad part - the Ooops: it seems that installing cudatoolkit-dev messes with your system (somehow, even though it was not done with root privileges) and not just with miniconda! After installing it, nvidia-detector crashes. I am not sure how serious this actually is (the system is still running and also works as expected after a reboot, at least for now, afaict), but if you do this, you do it at your own risk!

So, sorry for the long post, but maybe this actually answers @jarednielsen's original question and also offers a way to sidestep the issue, albeit at a certain cost ...

@mihaimaruseac
Collaborator

So, if I understand the comments correctly, TF 2.3 no longer has the slowdown and memory increase, right?

Thank you for the detailed response, it's been very helpful. Do you recall what the scipy warning was?

@dspyrhsu

Hi, in fact, I did not check the memory increase, but there was no more slowdown with TF2.3rc2.

The scipy message said something along the lines of tensorflow requiring scipy 1.4.1 while scipy 1.5 is installed, and the latter being incompatible (which, in fact, I doubt).

I believe, however, that the cudatoolkit which is installed along with tensorflow should be upgraded, so that the dev version does not have to be installed in order for tensorflow (or rather the possible optimizations) to work properly.

@mihaimaruseac
Collaborator

Oh, the scipy message is from PyPI/pip. It can safely be ignored.

@jarednielsen
Contributor Author
jarednielsen commented Aug 20, 2020

TF 2.3 still has the same speed and memory slowdowns as 2.2. Any update @sanjoy @reedwm @mihaimaruseac ?

@SivamPillai

Also experienced the same issue while upgrading from TF 2.0 to TF 2.3. The slowdown was almost 100% compared to the older version. Is it possible to share some optimization tricks that can help improve model performance? I have also been getting the warning about the batch running faster than 'on_train_batch_end' (not sure if that is related to the slowdown).

Example (the actual times vary based on model size, batch size, validation set, etc.):

WARNING:tensorflow:Callbacks method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0811s vs `on_train_batch_end` time: 0.1891s). Check your callbacks.

@laoyin
laoyin commented Nov 17, 2020

tensorflow version== '2.5.0-dev20201109'
cuda 11
GPU RTX 3090

When I run the code for over 7-8 hours, I get this error: "MemoryError: bad allocation"

Traceback (most recent call last):
  File "train_mspeech.py", line 53, in <module>
    ms.TrainModel(datapath, epoch = 50, batch_size = 16, save_step = 500)
  File "D:\ASR_project\asr\SpeechModel251.py", line 187, in TrainModel
    self.TestModel(self.datapath, str_dataset='train', data_count = 4)
  File "D:\ASR_project\asr\SpeechModel251.py", line 250, in TestModel
    pre = self.Predict(data_input, data_input.shape[0] // 8)
  File "D:\ASR_project\asr\SpeechModel251.py", line 326, in Predict
    r1 = r[0][0].eval(session=tf.compat.v1.Session())
  File "D:\ASR_project\asr\venv\lib\site-packages\tensorflow\python\framework\ops.py", line 921, in eval
    return _eval_using_default_session(self, feed_dict, self.graph, session)
  File "D:\ASR_project\asr\venv\lib\site-packages\tensorflow\python\framework\ops.py", line 5515, in _eval_using_default_session
    return session.run(tensors, feed_dict)
  File "D:\ASR_project\asr\venv\lib\site-packages\tensorflow\python\client\session.py", line 968, in run
    run_metadata_ptr)
  File "D:\ASR_project\asr\venv\lib\site-packages\tensorflow\python\client\session.py", line 1191, in _run
    feed_dict_tensor, options, run_metadata)
  File "D:\ASR_project\asr\venv\lib\site-packages\tensorflow\python\client\session.py", line 1369, in _do_run
    run_metadata)
  File "D:\ASR_project\asr\venv\lib\site-packages\tensorflow\python\client\session.py", line 1375, in _do_call
    return fn(*args)
  File "D:\ASR_project\asr\venv\lib\site-packages\tensorflow\python\client\session.py", line 1358, in _run_fn
    self._extend_graph()
  File "D:\ASR_project\asr\venv\lib\site-packages\tensorflow\python\client\session.py", line 1398, in _extend_graph
    tf_session.ExtendSession(self._session)
MemoryError: bad allocation
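
(Side note on the last frames: Predict builds a brand-new tf.compat.v1.Session() just to eval one tensor, and the failure is in _extend_graph, which suggests the graph keeps growing across calls over those 7-8 hours. If the code actually runs eagerly under TF 2.x, reading the value directly avoids creating a session per call; a sketch, assuming r[0][0] is an EagerTensor:)

# instead of: r1 = r[0][0].eval(session=tf.compat.v1.Session())
r1 = r[0][0].numpy()  # assumes eager mode, where the tensor already holds its value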

@rmothukuru
Contributor

Faced OOM Error upon running the code in this comment with Tensorflow Version 2.5. Please find the Gist.

@sushreebarsa sushreebarsa added TF 2.7 Issues related to TF 2.7.0 and removed TF 2.1 for tracking issues in 2.1 release TF 2.2 Issues related to TF 2.2 labels Dec 7, 2021
@tensorflowbutler
Member

Hi There,

This is a stale issue. As you are using an older version of tensorflow, we are checking to see if you still need help on this issue. Please test the issue with the latest TensorFlow (TF2.7 and tf-nightly). If the issue still persists with the newer versions of TF, please feel free to open it in keras-team/keras repository by providing details about the issue and a standalone code to reproduce the issue. Thanks!

Please note that Keras development has moved to a separate Keras-team/keras repository to focus entirely on only Keras. Thanks!

@gowthamkpr gowthamkpr reopened this Dec 20, 2022
@gowthamkpr gowthamkpr added the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Dec 20, 2022