
[2.2] XLA requires 2x GPU memory with sparse_categorical_crossentropy #38675

Closed
lgeiger opened this issue Apr 18, 2020 · 6 comments
Labels: comp:xla · stale · stat:awaiting response · TF 2.2 · type:performance

lgeiger (Contributor) commented Apr 18, 2020

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): Yes
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Ubuntu 18.04
  • TensorFlow installed from (source or binary): binary
  • TensorFlow version (use command below): 2.1.0, 2.2.0rc3 and tf-nightly
  • Python version: 3.7
  • CUDA/cuDNN version: 10.1 / 7.6.5.32
  • GPU model and memory: 4 x NVIDIA V100 on GCP

Describe the current behavior

I am currently trying to upgrade from TensorFlow 2.0.0 to 2.2.0rc3 and noticed a regression related to model training using XLA in a multi-GPU environment (GCP VM with 4 NVIDIA V100s).

The training code included below runs comfortably with a huge batch size of 3072 on 4 GPUs using the normal TensorFlow runtime. However, when enabling XLA with TF_XLA_FLAGS=--tf_xla_auto_jit=2, the same code runs out of GPU memory on the first batch of data. With XLA I can only use a maximum batch size of 1536 (50%) to prevent the code from running out of memory, which doesn't seem right.
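For reference, a minimal sketch of how auto-clustering is enabled for these runs (setting the flag from Python rather than exporting it in the shell is just one option; it has to happen before TensorFlow initialises its devices):

    import os

    # Equivalent to `export TF_XLA_FLAGS=--tf_xla_auto_jit=2` in the shell;
    # must be set before TensorFlow is imported / initialises the GPUs.
    os.environ["TF_XLA_FLAGS"] = "--tf_xla_auto_jit=2"

    import tensorflow as tf  # imported after setting the flag on purpose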

In which cases are the memory requirements of XLA and the default runtime similar?
To narrow down the possible causes, I found a few cases in which the maximum batch size for XLA and the normal runtime is the same:

  1. TensorFlow 2.0.0 doesn't seem to show this issue.

  2. Removing .prefetch(1) from the data pipeline fixes the issue.

  3. Changing the training to one-hot encoded labels seems to fix the increased XLA memory requirements as well. To test this I changed the preprocessing and loss to:

    def preprocessing(data):
        return (
            tf.cast(_decode_and_center_crop(data["image"]), tf.float32),
            tf.cast(tf.one_hot(data["label"], 1000), tf.float32),
        )

    and

    model.compile(
        optimizer="adam",
        loss="categorical_crossentropy",
        metrics=["accuracy", "top_k_categorical_accuracy"],
    )

The above conditions suggest that the prefetched int32 labels and sparse categorical crossentropy might be causing the regression with XLA, though I might be missing something here. Any help would be much appreciated.

Describe the expected behavior

GPU memory requirements (measured here by maximum usable batch size) should be similar between XLA and the default runtime.

Standalone code to reproduce the issue

import tensorflow as tf
import tensorflow_datasets as tfds
import larq_zoo as lqz  # !pip install larq_zoo==1.0b4

batch_size = 3072


def _decode_and_center_crop(image_bytes):
    """Crops to center of image with padding then scales image_size."""
    shape = tf.image.extract_jpeg_shape(image_bytes)
    image_height = shape[0]
    image_width = shape[1]
    image_size = 224

    padded_center_crop_size = tf.cast(
        (
            (image_size / (image_size + 32))
            * tf.cast(tf.minimum(image_height, image_width), tf.float32)
        ),
        tf.int32,
    )

    offset_height = ((image_height - padded_center_crop_size) + 1) // 2
    offset_width = ((image_width - padded_center_crop_size) + 1) // 2
    crop_window = tf.stack(
        [offset_height, offset_width, padded_center_crop_size, padded_center_crop_size]
    )
    image = tf.image.decode_and_crop_jpeg(image_bytes, crop_window, channels=3)
    return tf.image.resize(image, [image_size, image_size], method="bicubic")


def preprocessing(data):
    return (
        tf.cast(_decode_and_center_crop(data["image"]), tf.float32),
        data["label"],
    )


dataset = tfds.load(
    "imagenet2012:5.0.0",
    decoders={"image": tfds.decode.SkipDecoding()},
    split="train",
    data_dir="gs://my-data-bucket",
)

dataset = (
    dataset.map(preprocessing, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .batch(batch_size)
    .prefetch(1)
)

with tf.distribute.MirroredStrategy().scope():
    model = lqz.sota.QuickNet(weights=None)

    model.compile(
        optimizer="adam",
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy", "sparse_top_k_categorical_accuracy"],
    )

model.fit(dataset, epochs=5)

Other info / logs

I attached the XLA dumps below:
xla_dumps.tar.gz

@lgeiger lgeiger added the type:bug Bug label Apr 18, 2020
@amahendrakar amahendrakar added comp:xla XLA TF 2.1 for tracking issues in 2.1 release TF 2.2 Issues related to TF 2.2 type:performance Performance Issue and removed type:bug Bug labels Apr 20, 2020
@amahendrakar amahendrakar assigned ymodak and unassigned amahendrakar Apr 20, 2020
cheshire (Member) commented

@lgeiger Thanks for looking into this!
XLA in general makes memory fragmentation worse, but you are right that it should not increase memory consumption by that much.

We'll track this bug, but in general we have found it very difficult to deal with such problems when using auto-clustering, so we try to use explicit compilation scopes with tf.function(experimental_compile=True) instead. If you could change the test case to use that, it would be very helpful (though I understand if it's not possible, e.g. here the code inside the lqz model would probably need to be annotated). Investing time into annotation could also make the possible performance impact more apparent (by identifying a chunk in the profiler which is too slow and should be optimized, and adding an explicit annotation around it).
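A minimal sketch of such an explicit compilation scope (the train_step function below is hypothetical and not part of the original test case; experimental_compile was later renamed to jit_compile):

    import tensorflow as tf

    @tf.function(experimental_compile=True)  # explicit XLA scope instead of auto-clustering
    def train_step(model, optimizer, images, labels):
        with tf.GradientTape() as tape:
            predictions = model(images, training=True)
            loss = tf.reduce_mean(
                tf.keras.losses.sparse_categorical_crossentropy(labels, predictions)
            )
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        return loss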

@cheshire cheshire assigned cheshire and unassigned ymodak Apr 20, 2020
lgeiger (Contributor, Author) commented Apr 20, 2020

@cheshire Thanks for taking a look, I appreciate your help!
Unfortunately, explicitly setting compilation scopes is currently not an option for us, since I am mostly dealing with research code where flexibility and readability are more important than using XLA to get the best possible performance.

I tried to narrow the problem down a bit more, though, and unfortunately after more testing it seems that even when removing .prefetch(1) or using one-hot encoded labels I still see out-of-memory errors. So identical code might or might not run out of memory for seemingly arbitrary reasons. Does XLA auto-clustering generate a deterministic graph, or can results differ from compilation to compilation?

I'll try to investigate further, but it's tricky to narrow it down.

lgeiger (Contributor, Author) commented Apr 20, 2020

For now it looks like the only reliable way for me to get this working is to either disable XLA autoclustering or to not use distributed training.

mohantym (Contributor) commented Dec 1, 2022

Hi @lgeiger!
We are checking to see whether you still need help with this issue.
I ran into an "access denied" error for the GCS bucket while trying to replicate it.

Here are my pointers on this issue: it seems the XLA syntax has not been followed. Could you follow the XLA syntax (auto-clustering), as sketched below, and let us know the results?
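A minimal sketch of the auto-clustering syntax (this is not the attached gist; the string argument to tf.config.optimizer.set_jit assumes a recent TF release, older releases take True/False):

    import tensorflow as tf

    # Enable XLA auto-clustering through the public API instead of TF_XLA_FLAGS.
    tf.config.optimizer.set_jit("autoclustering")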

I attached a gist with the XLA syntax for reference.

Thank you!

@mohantym mohantym added the stat:awaiting response Status - Awaiting response from author label Dec 1, 2022
google-ml-butler (bot) commented

This issue has been automatically marked as stale because it has no recent activity. It will be closed if no further activity occurs. Thank you.

@google-ml-butler google-ml-butler bot added the stale This label marks the issue/pr stale - to be closed automatically if no activity label Dec 8, 2022
google-ml-butler (bot) commented

Closing as stale. Please reopen if you'd like to work on this further.
