TF 2.3 training slowed down by 15% compared to 2.2 #41827

Open
lgeiger opened this issue Jul 28, 2020 · 33 comments
Assignees
Labels
comp:keras Keras related issues TF 2.3 Issues related to TF 2.3 type:performance Performance Issue

Comments

@lgeiger
Contributor
lgeiger commented Jul 28, 2020

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): yes
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Ubuntu 18.04
  • TensorFlow installed from (source or binary): binary
  • TensorFlow version (use command below): 2.3.0, 2.4.0.dev20200728
  • Python version: 3.7.8
  • CUDA/cuDNN version: 10.1 / 7.6.5.32
  • GPU model and memory: NVIDIA V100 on 12 vCPUs, 40 GB memory GCP node

Describe the current behavior

When upgrading from TensorFlow 2.2.0 to 2.3.0 we observed a 15-18% slowdown in training speed for our workloads. Unfortunately I wasn't able to find an easy-to-reproduce example before the stable release was cut, but below is a code example that illustrates the performance degradation.

When running the training script on a single NVIDIA V100, a 15% performance loss compared to 2.2 can be observed, which is still noticeable in the latest nightly:

| version | epoch time | step time | GPU idle time |
| --- | --- | --- | --- |
| 2.2.0 | 34 s | 124.3 ms | 19.7 ms (15.6 %) |
| 2.3.0 | 39 s | 141.9 ms | 37.2 ms (26.1 %) |
| 2.4.0.dev20200728 | 38 s | 136.2 ms | 31.6 ms (23.2 %) |

On Device: total self-time (grouped by type)

[Profiler screenshots of the on-device total self-time breakdown for 2.2.0, 2.3.0, and 2.4.0.dev20200728.]

The example uses auto mixed precision, but the slowdown can also be observed when running in float32 or with multi-GPU training. Looking at the generated execution profile, the slowdown can be explained by increased GPU idle time. Since the training data is cached in memory there should be no I/O bottleneck, so I am not sure whether this performance regression is caused by tf.data or by the runtime itself.
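
For reference, a minimal sketch of how per-epoch wall-clock times like those in the table above can be measured with a Keras callback. This is only an illustration and not necessarily how the numbers were produced; the step times and GPU idle times come from the TensorBoard profiler.

```python
import time

import tensorflow as tf


class EpochTimer(tf.keras.callbacks.Callback):
    """Print the wall-clock duration of every epoch."""

    def on_epoch_begin(self, epoch, logs=None):
        self._start = time.time()

    def on_epoch_end(self, epoch, logs=None):
        print(f"epoch {epoch}: {time.time() - self._start:.1f} s")
```

Passing an `EpochTimer()` instance alongside the TensorBoard callback in `model.fit` makes the epoch times easy to compare across TF versions.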

Describe the expected behavior

TensorFlow 2.3 should show equally fast training performance compared to 2.2.

Standalone code to reproduce the issue

import tensorflow as tf
import tensorflow_datasets as tfds

batch_size = 64


def _decode_and_center_crop(image_bytes):
    """Crops to center of image with padding then scales image_size."""
    shape = tf.image.extract_jpeg_shape(image_bytes)
    image_height, image_width, image_size = shape[0], shape[1], 224

    padded_center_crop_size = tf.cast(
        (
            (image_size / (image_size + 32))
            * tf.cast(tf.minimum(image_height, image_width), tf.float32)
        ),
        tf.int32,
    )

    offset_height = ((image_height - padded_center_crop_size) + 1) // 2
    offset_width = ((image_width - padded_center_crop_size) + 1) // 2
    crop_window = tf.stack(
        [offset_height, offset_width, padded_center_crop_size, padded_center_crop_size]
    )
    image = tf.image.decode_and_crop_jpeg(image_bytes, crop_window, channels=3)
    return tf.image.resize(image, [image_size, image_size], method="bicubic")


def preprocessing(data):
    return (
        tf.cast(_decode_and_center_crop(data["image"]), tf.float32),
        data["label"],
    )


dataset = tfds.load(
    "imagenette", decoders={"image": tfds.decode.SkipDecoding()}, split="train",
)

dataset = (
    dataset.cache()
    .repeat(2)  # Artificially increase time per epoch to make it easier to measure
    .map(preprocessing, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .batch(batch_size)
    .prefetch(1)
)

with tf.distribute.MirroredStrategy().scope():
    model = tf.keras.applications.ResNet50(weights=None)

    model.compile(
        optimizer=tf.train.experimental.enable_mixed_precision_graph_rewrite(
            tf.keras.optimizers.Adam(), loss_scale="dynamic"
        ),
        loss="sparse_categorical_crossentropy",
    )

tb_cbk = tf.keras.callbacks.TensorBoard(f"logs/{tf.__version__}", profile_batch=300)
model.fit(dataset, verbose=2, epochs=3, callbacks=[tb_cbk])

Other info / logs

TensorBoard profiles for the runs mentioned above are available at tb-profile.zip

@mihaimaruseac @jsimsa @guptapriya do you mind taking a look at this?

@geetachavan1
Contributor

I was able to reproduce the issue. Here is the gist.

@guptapriya
Contributor

Thanks for the report @lgeiger. We will look more into this. The code sample uses MirroredStrategy, so I wanted to clarify: when you said this regression is noticed on a single GPU as well, is that with or without MirroredStrategy? If the latter, we would start investigating that case first (i.e. 1 GPU, no distribution, no mixed precision).

@lgeiger
Contributor Author
lgeiger commented Jul 29, 2020

@guptapriya Thanks for looking into this.

The code sample uses MirroredStrategy - so I wanted to clarify when you said this regression is noticed on a single GPU as well - is that with or without MirroredStrategy?

Sorry about that, I missed that since I copied the example from a past issue I had with multi-GPU. I ran a few more benchmarks with the above example on a single GPU machine:

| version | FP32 | FP32 Mirrored | FP16 | FP16 Mirrored |
| --- | --- | --- | --- | --- |
| 2.2.0 | 55 s | 55 s | 36 s | 34 s |
| 2.3.0 | 54 s | 58 s | 37 s | 39 s |

Indeed, it looks like whether mirrored strategy is used or not has a large influence. I am not sure why there is a difference in execution speed for mixed precision with and without a strategy, although I think that might be a separate issue.

One thing to note is that 2.3 logs the following deprecation warning when used with mirrored strategy, which wasn't present before, but that might be unrelated as well:

WARNING:tensorflow:From ~/.local/lib/python3.7/site-packages/tensorflow/python/data/ops/multi_device_iterator_ops.py:601: get_next_as_optional (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.data.Iterator.get_next_as_optional()` instead.

Note that this slowdown is more noticeable with mixed precision, as the kernel execution time is smaller, so the increased idle time is easier to spot.
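
For completeness, a minimal sketch of how the single-GPU FP32 variant of the benchmark could look, assuming the only changes relative to the script above are dropping the MirroredStrategy scope and the mixed-precision graph rewrite (the exact variant scripts are not shown in this thread):

```python
import tensorflow as tf

# Hypothetical single-GPU FP32 variant: same dataset as above,
# but no strategy scope and no mixed-precision graph rewrite.
model = tf.keras.applications.ResNet50(weights=None)
model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss="sparse_categorical_crossentropy",
)
model.fit(dataset, verbose=2, epochs=3)
```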

@amahendrakar amahendrakar added TF 2.3 Issues related to TF 2.3 type:performance Performance Issue comp:keras Keras related issues and removed type:bug Bug labels Jul 29, 2020
@guptapriya
Contributor
guptapriya commented Jul 29, 2020

Thanks @lgeiger for the update, that helps a lot. I can confirm that we have verified the regression (it looks like it happened sometime back in April).
We will update here when we have more information on the root cause and fixes.

@lgeiger
Contributor Author
lgeiger commented Jul 29, 2020

@guptapriya Thank you very much for the fast response, I am glad that you were able to verify the regression.

I can confirm that we have verified the regression (looks like it happened sometime back in April)

April sounds like a long time for the regression to stay unnoticed. Am I doing something abnormal in the example or am I missing a best practice here?

@guptapriya
Contributor

Actually, we are no longer sure about the timing, because the benchmark I used to get the timing was itself changed in April. So we are now looking into the timing of the regression again. It may have come up later.
Once we determine the root cause, we should be able to tell whether there is anything in your code that makes this show up, or whether it is simply something on the TF side that went unnoticed.
In your experience, when did you start noticing the regression?

@lgeiger
Contributor Author
lgeiger commented Jul 29, 2020

Makes sense, thanks for the help.

In your experience, when did you start noticing the regression?

I usually stay on the stable versions of TensorFlow for production workloads, so I first noticed it around 2.3 RC0 or RC1 when running our internal ImageNet sanity check before upgrading the TF version. But it took some time to find a reliably reproducible example that I could use to open the issue.

@zongweiz
Contributor
zongweiz commented Jul 29, 2020

@lgeiger thanks for providing the TensorBoard profiles. We have some observations from your profiles, which use FP16. You mentioned that the regression also happens with FP32 (and multi-GPU). Do you happen to have the FP32 profiles for TF 2.2 and 2.3? Thanks.

@lgeiger
Contributor Author
lgeiger commented Jul 29, 2020

@zongweiz Sure, I reran the above code example (float32 with mirrored strategy) to generate these profiles:
tb-profile-fp32.zip

@lgeiger
Contributor Author
lgeiger commented Aug 7, 2020

I hope the profiles are helpful for debugging. Do you have any news on the resolution of this issue?

@guptapriya
Contributor

@jaingaurav is looking into a potential fix I believe.

@jaingaurav
Contributor

Yes, we found some unintended host-to-device copies caused by a previous change, which I am trying to eliminate.
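
Unintended copies like these can sometimes be spotted by logging op placement; a small sketch of that general debugging technique (an assumption on my part, not necessarily how the copies were found here):

```python
import tensorflow as tf

# Log the device every op is placed on; CPU-placed ops feeding GPU ops
# are candidates for unintended host-to-device copies.
tf.debugging.set_log_device_placement(True)
```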

@dathudeptrai
dathudeptrai commented Aug 21, 2020

I also found that TensorFlow 2.3 mixed precision doesn't speed up training the way TensorFlow 2.2.0 does. The only improvement I see in TensorFlow 2.3 is the lower overhead of multi-GPU training.

@lgeiger
Contributor Author
lgeiger commented Sep 2, 2020

Yes, we found some unintended host to device copies caused by a previous change, that I am trying to eliminate.

@jaingaurav Thanks for looking into it. Do you know if this fix will make it into the 2.4 release?

@jaingaurav
Contributor

@lgeiger: Yes, this is currently planned for the 2.4 release. However, the fix is still being worked on.

@goldiegadde goldiegadde added this to To do in TensorFlow 2.4.0 via automation Sep 3, 2020
@lgeiger
Contributor Author
lgeiger commented Nov 3, 2020

@jaingaurav I just checked with the new release candidate and CUDA 11, but unfortunately this issue still exists and the increased idle time is clearly visible in the profiles.

Below are the numbers for 2.4.0rc0 including the profiles for the runs:

| version | FP32 | FP32 Mirrored | FP16 | FP16 Mirrored | Keras FP16 | Keras FP16 Mirrored |
| --- | --- | --- | --- | --- | --- | --- |
| 2.4.0rc0 | 54 s | 60 s | 40 s | 41 s | 35 s | 39 s |

@jaingaurav
Contributor

cc: @rohan100jain, @zongweiz looks like the fix in 673b993 didn't quite work as well as we'd hoped.

@lgeiger
Contributor Author
lgeiger commented Nov 3, 2020

Thanks for looking into it. I also updated the table with measurements using the new Keras mixed precision API, which makes the slowdown easier to spot.
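
For reference, a sketch of what the "Keras FP16" runs presumably look like, assuming the stable TF 2.4 mixed precision API was used (the exact setup is not shown in this thread):

```python
import tensorflow as tf

# Hypothetical "Keras FP16" setup: enable the mixed_float16 policy before
# building the model; Keras is expected to apply dynamic loss scaling
# to the optimizer in compile() under this policy.
tf.keras.mixed_precision.set_global_policy("mixed_float16")

model = tf.keras.applications.ResNet50(weights=None)
model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss="sparse_categorical_crossentropy",
)
```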

@lgeiger
Contributor Author
lgeiger commented Nov 10, 2020

@rohan100jain Has there been any progress on this? From briefly skimming the changes in RC1, it looks like it doesn't include a fix for this.

This bug has been blocking me from upgrading to 2.3, so it would be great to get it resolved in 2.4. Please let me know if there is anything I can do from my side to help you debug this further.

@lgeiger
Contributor Author
lgeiger commented Nov 10, 2020

I double checked and the issue still exists in tf-nightly 2.5.0-dev20201110.

@lgeiger
Contributor Author
lgeiger commented Nov 17, 2020

@rohan100jain @zongweiz @jaingaurav I did a bit more testing with 2.4.0rc1, and it looks like when using .batch(batch_size, drop_remainder=True) instead of .batch(batch_size), mirrored strategy and normal training behave the same in float32, and the slowdown is significantly reduced for mixed precision training.

I am a bit confused why keeping the remainder results in such a significant performance degradation, but I haven't looked into it in detail yet. The mixed precision training runs still show a high idle time, but the slowdown for mirrored strategy isn't as significant as before. Here are the TensorBoard profiles for the measurements shown below (the pipeline change is sketched after the table):

| version | FP32 | FP32 Mirrored | Keras FP16 | Keras FP16 Mirrored |
| --- | --- | --- | --- | --- |
| 2.4.0rc1 (drop_remainder=True) | 54 s | 54 s | 34 s | 35 s |
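
For reference, the only change to the input pipeline relative to the original repro script is the batch call:

```python
dataset = (
    dataset.cache()
    .repeat(2)
    .map(preprocessing, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .batch(batch_size, drop_remainder=True)  # workaround: drop the final partial batch
    .prefetch(1)
)
```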

@zongweiz
Contributor
zongweiz commented Nov 17, 2020

@lgeiger Thanks very much for the information. Yes, your observation matches what we found. We have tracked down the performance issue to a tf.cond inside the Keras batch norm layer (which is there to handle empty batches). Setting drop_remainder=True avoids empty/partial batches and works around the problem.

I think the reason we see more regression with FP16 is that the above problem is more significant when the workload has higher overhead in launching GPU kernels: FP16 makes some compute kernels more efficient, which makes the workload more kernel-launch bound.
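
A toy illustration (not from this issue) of the difference: without drop_remainder the final batch can be partial and the batch dimension stays dynamic, which is presumably what forces the batch norm layer to guard against empty batches at runtime; with drop_remainder=True the batch dimension is static.

```python
import tensorflow as tf

ds = tf.data.Dataset.range(130)  # hypothetical size that is not a multiple of 64

print([int(b.shape[0]) for b in ds.batch(64)])                       # [64, 64, 2] -> partial final batch
print([int(b.shape[0]) for b in ds.batch(64, drop_remainder=True)])  # [64, 64]

print(ds.batch(64).element_spec.shape)                       # (None,) -> dynamic batch dimension
print(ds.batch(64, drop_remainder=True).element_spec.shape)  # (64,)   -> static batch dimension
```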

@lgeiger
Contributor Author
lgeiger commented Nov 17, 2020

@zongweiz Thanks for looking into it. Looking forward to a fix.

@lgeiger
Contributor Author
lgeiger commented Nov 23, 2020

@rohan100jain Do you have any updates whether a fix will make it into the 2.4 release?

@lgeiger
Contributor Author
lgeiger commented Nov 30, 2020

@rohan100jain @zongweiz @jaingaurav Sorry for pinging you again. Do you have any updates on whether the fix will make it into the 2.4 stable release?

@guptapriya
Contributor

Hey @lgeiger, I believe the fix did not make it into 2.4, unfortunately. @rohan100jain @zongweiz @goldiegadde please correct me if that is not the case.

@lgeiger
Contributor Author
lgeiger commented Dec 2, 2020

Hey @lgeiger i believe the fix did not make it to 2.4 unfortunately.

Thanks for the response, that's really unfortunate since the regression has been there since 2.3. But at least we can now upgrade to 2.4 by enabling drop_remainder=True as a workaround until a fix lands.

@geetachavan1 geetachavan1 added this to To do in TensorFlow 2.5 Jan 6, 2021
@geetachavan1 geetachavan1 removed this from To do in TensorFlow 2.4.0 Jan 6, 2021
@lgeiger
Contributor Author
lgeiger commented Feb 5, 2021

@guptapriya has there been any progress on this issue? It would be good if a fix could make it into TF 2.5, since in our current workloads running on 4 GPUs we are seeing slowdowns of 30-80% compared to TF 2.2 due to this bug (using XLA makes the regression even more dramatic).
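
For context, a sketch of one common way XLA auto-clustering can be enabled for runs like these; this is an assumption, since the thread does not show how XLA was turned on:

```python
import tensorflow as tf

# Turn on XLA auto-clustering for TensorFlow graphs.
tf.config.optimizer.set_jit(True)
```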

@guptapriya
Contributor

@lgeiger It looks like the last fix that was tried in November did not resolve the issue, and I don't see any other updates since then. Checking with @rohan100jain, will update when I know more.

@rohan100jain
Member

Apologies, but the fix I tried last year had to be rolled back / didn't do what we expected. We'll continue to work on it and get a fix out by 2.5.

@lgeiger
Contributor Author
lgeiger commented Feb 12, 2021

@rohan100jain Thanks for the update. Let me know when a fix lands and I can rerun the benchmarks to verify.

@rohan100jain
Member

I'm sorry, but having looked into this issue, there isn't really any easy way of fixing it without rolling back a change (f0d0485) that enhances the dtype coverage of our GPU ops and improves the consistency of TensorFlow in general. This issue has exposed some problems with our device placement that we need to fix; we're planning to work on them and will have an RFC for it. I'll therefore recommend that you continue to use the drop_remainder=True workaround for now.

@lgeiger
Contributor Author
lgeiger commented Mar 31, 2021

Thanks for the update. I will continue using drop_remainder=True as a workaround for now. I hope there will be a fix for this soon, since I think MirroredStrategy together with the default batching is quite a common use case for people running on GPUs.

It would be awesome if it were possible to add this example (or a similar one using MirroredStrategy and a large cached dataset) to your internal regression testing suite. In a lot of the TF version upgrades I have done in the past, I discovered some sort of memory issue or performance regression that was reproducible with code very similar to the example mentioned above (see #36240, #38617, #38655). It would be excellent if issues like that could be caught automatically so they don't make it into the stable releases.
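
As a sketch of the kind of check such a regression suite could run (hypothetical helper, with an arbitrary time budget as a placeholder):

```python
import time


def check_epoch_time(model, dataset, budget_seconds=60.0):
    """Hypothetical regression check: fail if one training epoch exceeds a time budget."""
    start = time.time()
    model.fit(dataset, epochs=1, verbose=0)
    elapsed = time.time() - start
    assert elapsed < budget_seconds, f"epoch took {elapsed:.1f} s, budget is {budget_seconds} s"
```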
