

tf.data.Dataset .map().batch() pattern is not matched to use fused implementation. #53572

Open
mcourteaux opened this issue Dec 29, 2021 · 10 comments
Labels
comp:data (tf.data related issues), comp:tensorboard (Tensorboard related issues), stat:awaiting tensorflower (Status - Awaiting response from tensorflower), TF 2.7 (Issues related to TF 2.7.0), type:bug (Bug)

Comments

@mcourteaux

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): yes
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Ubuntu 20.04
  • Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device:
  • TensorFlow installed from (source or binary): source
  • TensorFlow version (use command below): v1.12.1-69264-g0cdf35562dc 2.9.0
  • Python version: 3.8.10
  • Bazel version (if compiling from source):
  • GCC/Compiler version (if compiling from source):
  • CUDA/cuDNN version: 11.5 / 8.3
  • GPU model and memory: GTX1660 Ti

Describe the current behavior
Combining tf.data.Dataset.map() with .batch() does not use the fused MapAndBatch implementation.

Describe the expected behavior
map().batch() should use the fused MapAndBatch implementation. Currently, the fused implementation is only used with the deprecated experimental.map_and_batch() transformation.
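
For reference, a minimal sketch of the two variants being compared (illustrative names only; the actual repro follows below):

import tensorflow as tf

ds = tf.data.Dataset.range(1024)

# Deprecated transformation that directly builds the fused MapAndBatch dataset:
fused = ds.apply(tf.data.experimental.map_and_batch(tf.identity, 16))

# Idiomatic form that the map-and-batch fusion optimization is expected to rewrite
# into the same fused dataset, but which does not happen in the repro below:
not_fused = ds.map(tf.identity).batch(16)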

Contributing

  • Do you want to contribute a PR? (yes/no): no
  • Briefly describe your candidate solution (if contributing):

Standalone code to reproduce the issue

import os
import datetime
from tqdm import tqdm
import numpy as np

import tensorflow as tf
print('TF version', tf.__version__)

gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        # Currently, memory growth needs to be the same across GPUs
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        print(len(gpus), 'Physical GPUs,', len(logical_gpus), 'Logical GPUs')
    except RuntimeError as e:
        # Memory growth must be set before GPUs have been initialized
        print(e)


@tf.function
def do_stuff(wmat, tf_var):
    with tf.device("/gpu:0"):
        S = tf.constant(0.0)
        for i in tf.range(4):
            fi = tf.cast(i, dtype=tf.float32)
            A = tf.math.lgamma(tf.tanh(tf.matmul(wmat + fi, tf.transpose(wmat - fi, [0, 2, 1]))))
            S += tf.reduce_sum(A)
        error = tf.reduce_mean(tf_var)
        return error, S

exp_uuid = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")

n_batches = 512


def gen():
    for i in range(n_batches):
        with tf.device("/cpu:0"): # Make sure it comes from CPU
            r = tf.ones((400,800))
        yield r

option_names = ['map().batch()', 'map_and_batch()']
for option in range(2):

    with tf.device("/cpu:0"):
        dataset = tf.data.Dataset.from_generator(gen, output_types=tf.float32)

        def my_identity(x):
            with tf.device("/cpu:0"):
                print("my_identity input:", x, x.device)
                y = tf.identity(x)
                print("my_identity output:", y, y.device)
                return y

        if option == 0:
            ## Option 0: map().batch()
            dataset = dataset.map(my_identity).batch(16)

        elif option == 1:
            ## Option 1: deprecated map_and_batch()
            dataset = dataset.apply(tf.data.experimental.map_and_batch(my_identity, 16))

    gpu_transform = tf.data.experimental.prefetch_to_device('/gpu:0', buffer_size=4)
    dataset = dataset.apply(gpu_transform)


    tf_var = tf.Variable(tf.zeros(3))
    adam = tf.keras.optimizers.Adam(1e-4)
    logpath = os.path.join('data', 'logs', 'pa_' + exp_uuid + '_' + option_names[option])

    tf.profiler.experimental.start(logpath)
    start = datetime.datetime.now()
    for b, wmat in tqdm(enumerate(dataset)):
        with tf.GradientTape() as tape:

            if b == 0:
                print('\n\n dataset element device', wmat.device)
                print('\n')

            # Do some calculations
            result = do_stuff(wmat, tf_var)

        grads = tape.gradient(result[0], [tf_var])
        adam.apply_gradients(zip(grads, [tf_var]))
    stop = datetime.datetime.now()
    tf.profiler.experimental.stop()

    print(f'\n\nOption {option_names[option]}\n===========================\n')
    print(logpath)
    print('Time lapsed=', stop - start)
    print("\n\n")

Other info / logs

map().batch() (option 0 in the script):
[TensorBoard trace screenshot]
Symptoms:

  • The blocks Iterator::FlatMap and Iterator::BatchV2 are stacked on top of each other.
  • The MemcpyH2D (selected, see the details panel) is coming from pageable memory instead of pinned memory (which is what MapAndBatch uses). Because the source is pageable memory, the copy can't overlap with kernel computations.

map_and_batch() (option 1 in the script):
[TensorBoard trace screenshot]
Evidence:

  • The MapAndBatch block is used.
  • The MemcpyH2D comes from pinned memory (see the details pane) and overlaps with kernel computations.

The whole point of pinned memory is to let data uploads overlap with kernel computations. The dataset therefore needs to be produced into pinned host memory, which the driver can then upload asynchronously without an extra copy. See #43905 (comment) and #43905 (comment).
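
As an aside, the fusion itself is one of the tf.data graph optimizations. Assuming the experimental_optimization.map_and_batch_fusion field exists in this TF version (it is on by default), the intent can be stated explicitly like this (a sketch reusing dataset and my_identity from the repro above):

options = tf.data.Options()
options.experimental_optimization.map_and_batch_fusion = True  # on by default; set here only to make the intent explicit

dataset = dataset.map(my_identity).batch(16)
dataset = dataset.with_options(options)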

This is a follow-up on #43905.

@mcourteaux added the type:bug label Dec 29, 2021
@tilakrayal added the TF 1.12 and comp:data labels Dec 30, 2021
@tilakrayal
Contributor

@mcourteaux,
We see that you are using TF version 1.12. TF 1.x is not actively supported; please update to the latest stable TF v2.7 and let us know if you still see the same issue.

@tilakrayal added the stat:awaiting response label Dec 30, 2021
@mcourteaux
Author
mcourteaux commented Dec 30, 2021

No, I'm most definitely not. This was a fresh build from yesterday's master branch. I don't know why the version script reports 1.12; it's definitely wrong. I moved to TF 2 years ago. Note that that is the GIT_VERSION; tf.version.VERSION gives 2.9.0.

@tilakrayal
Contributor

@mcourteaux,
I was able to execute the code without any issues. Please find the gist of it here. Please provide the error log, and also confirm if anything is missing here. Thanks!

@tilakrayal added the TF 2.7 and stat:awaiting response labels and removed the stat:awaiting response and TF 1.12 labels Dec 31, 2021
@mcourteaux
Author
mcourteaux commented Jan 1, 2022

First, a little frustration: @tilakrayal, I get the sense you are not paying attention. How does Google hope to get contributions to a project if useful feedback is dismissed as either wrong or not worth reading? It's frustrating that I spent around two hours identifying this problem, and then another hour making a clean MWE that demonstrates it. Compare my three-hour effort to the flow of this issue... I'll tag people who know what's going on: @jsimsa, @aaudiber.


You somehow managed to make a notebook with the code from the linked issue, not mine. There is no error message with my code. I showed how the MapAndBatchDataset implementation is only used with tf.data.experimental.map_and_batch(), and not with map().batch() as is actually promised. To see this, you look at the performance trace in TensorBoard, not at a log or error message.

Happy new year!

@tensorflowbutler removed the stat:awaiting response label Jan 3, 2022
@tilakrayal added the comp:tensorboard label Jan 3, 2022
@tilakrayal assigned Saduf2019 and unassigned tilakrayal Jan 3, 2022
@jsimsa
Contributor
jsimsa commented Jan 3, 2022

Hi @mcourteaux, thank you for the detailed repro and sorry for the initial response. I will have someone on the tf.data team take a closer look.

@Saduf2019 assigned sachinprasadhs and unassigned Saduf2019 Jan 4, 2022
@sachinprasadhs added the stat:awaiting tensorflower label Jan 4, 2022
@jsimsa
Contributor
jsimsa commented Jan 4, 2022

@wilsingosti is taking a look

@wilsingosti
Contributor

The my_identity function is a no-op. In the map().batch() case, the Map is eliminated by the tf.data noop_elimination optimization, so what is left is just Batch. If you replace my_identity with a function that is not a no-op, you should see the MapAndBatch implementation.
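
For illustration (my example, not the commenter's, and assuming the optimizer only strips identity-like functions): a map function whose body contains a real op should survive noop_elimination, e.g.

dataset = dataset.map(lambda x: x * 1.0).batch(16)  # Mul is not an identity op, so the Map should be kept and can be fused with Batch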

@jsimsa
Contributor
jsimsa commented Jan 5, 2022

IIUC, the map in the repro above is introduced to make sure the memory allocated for the batch is pinned. From that perspective, the map is intended to be a no-op. To keep the map and still trigger the fusion, it should be sufficient to disable the no-op elimination using tf.data options as follows:

dataset = ... # your dataset

options = tf.data.Options()
options.experimental_optimization.noop_elimination = False

dataset = dataset.with_options(options)
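
Applied to the repro's map().batch() branch, that would look roughly like the following sketch (under the same assumption; not a verified fix):

options = tf.data.Options()
options.experimental_optimization.noop_elimination = False  # keep the identity map so map+batch fusion can still fire

dataset = tf.data.Dataset.from_generator(gen, output_types=tf.float32)
dataset = dataset.map(my_identity).batch(16)
dataset = dataset.with_options(options)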

@mcourteaux
Author

@wilsingosti Thanks for checking this. I'm wondering why most Dataset implementations don't use gpu_compatible memory when allocating their outputs. Wouldn't that be useful in general, so we wouldn't have to rely on MapAndBatch being selected just to get the data into pinned memory?

@wilsingosti
Contributor

Yes, it would be useful in general. AFAIK, it has just not been prioritized so far. I will try to do this for Batch dataset.
