
Dataset.shuffle leads to worse training performance due to chunked processing #36626

Open
Flamefire opened this issue Feb 10, 2020 · 8 comments
Labels: comp:data (tf.data related issues), TF 2.1 (tracking issues in the 2.1 release), type:feature (feature request), type:performance (performance issue)

Comments

@Flamefire
Contributor
Flamefire commented Feb 10, 2020

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): No
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Linux Mint
  • TensorFlow installed from (source or binary): binary
  • TensorFlow version (use command below): 2.1.0
  • Python version: 3.7.4

Describe the current behavior

Dataset.shuffle is (essentially) described as buffering N elements and then returning one of them chosen uniformly at random, refilling the freed slot from the input. Hence the input data is effectively processed in chunks. In the extreme, consider a shuffle buffer of size 2: the first element returned in an epoch can only be one of the first 2 elements, and the following code (printing the 3rd and 4th elements) only ever produces values from 0 to 4:

dataset = tf.data.Dataset.range(10)
dataset = dataset.shuffle(2)
for _ in range(20):
    print(list(dataset.as_numpy_iterator())[2:4])

For good training performance (in the sense that accuracy increases quickly) a complete shuffle of the dataset is required. This becomes obvious if one considers training data that is (accidentally or purposely) ordered (in the MNIST example: first all zeros, then all ones, etc.). With a small shuffle buffer many batches will then consist mostly or entirely of a single label value, which works poorly with SGD-style optimizers.
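To make the chunking concrete, here is a pure-Python sketch of the buffered-shuffle algorithm described above (not TensorFlow's actual implementation; `buffered_shuffle` is a made-up name for illustration):

```python
import random

def buffered_shuffle(items, buffer_size, seed=0):
    """Sketch of tf.data-style buffered shuffling: hold up to
    `buffer_size` elements, repeatedly yield a uniformly random one,
    and refill the freed slot from the input stream."""
    rng = random.Random(seed)
    it = iter(items)
    buf = []
    # Fill the buffer.
    for x in it:
        buf.append(x)
        if len(buf) == buffer_size:
            break
    # Steady state: one element out, one element in.
    for x in it:
        i = rng.randrange(len(buf))
        yield buf[i]
        buf[i] = x
    # Input exhausted: drain the remaining buffered elements.
    rng.shuffle(buf)
    yield from buf

# With a buffer of size B, the k-th output (0-indexed) can only be one
# of the first B + k input elements -- early outputs come from early inputs.
out = list(buffered_shuffle(range(10), 2))
assert sorted(out) == list(range(10))
assert out[0] in (0, 1)
```

This makes the reported behavior easy to see: the output is always a permutation of the input, but which elements can appear early is bounded by the buffer size.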

Some statistics on MNIST (validation after 10 epochs by shuffle buffer size):

  100: Eval loss: 0.29262512158124876, accuracy: 0.9161659
 1000: Eval loss: 0.2921471730925334,  accuracy: 0.9165665
10000: Eval loss: 0.2914975070131895,  accuracy: 0.9171675
60000: Eval loss: 0.29154436285488117, accuracy: 0.91696715

As you can see, the accuracy increases with the buffer size, everything else being held constant.

Describe the expected behavior

The whole dataset should be shuffled. This requires the concept of random-access datasets. I believe the TFRecord format supports random access(?), so the shuffle operation could draw random elements from the whole dataset.

Code to reproduce the issue

import tensorflow_datasets as tfds
import tensorflow as tf
tfds.disable_progress_bar()


def make_datasets_unbatched():
    # Scaling MNIST data from (0, 255] to (0., 1.]
    def scale(image, label):
        image = tf.cast(image, tf.float32)
        image /= 255
        return image, label

    datasets, info = tfds.load(name='mnist',
                               with_info=True,
                               as_supervised=True)

    return {key: ds.map(scale).cache()
            for key, ds in datasets.items()}, info.splits


def build_and_compile_cnn_model():
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(32,
                               3,
                               activation='relu',
                               input_shape=(28, 28, 1)),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(10, activation='softmax')
    ])
    model.compile(loss=tf.keras.losses.sparse_categorical_crossentropy,
                  optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),
                  metrics=['accuracy'])
    return model


datasets, infos = make_datasets_unbatched()
model = build_and_compile_cnn_model()
# Init: run one batch so the model variables are built before saving weights
x, y = next(iter(datasets['train'].batch(64)))
model.train_on_batch(x, y)
WEIGHTS_PATH = './init-weights'
model.save_weights(WEIGHTS_PATH)

for buffer_size in (100, 1000, 5000, 10000, 60000):
    print("Buffer size: %s" % buffer_size)
    model.load_weights(WEIGHTS_PATH)

    model.fit(x=datasets['train'].shuffle(buffer_size).batch(64).repeat(),
              epochs=10,
              steps_per_epoch=infos['train'].num_examples // 64)
    eval_loss, eval_acc = model.evaluate(datasets['test'].batch(64),
                                         steps=infos['test'].num_examples //
                                         64)
    print("Eval loss: %s, accuracy: %s" % (eval_loss, eval_acc))
ravikyram added the comp:data, TF 2.1, and type:performance labels on Feb 11, 2020
@ravikyram
Contributor

I have tried this in Colab with TF 2.1.0 and the nightly version. Please find the gist here. Thanks!

@ravikyram ravikyram assigned gowthamkpr and unassigned ravikyram Feb 11, 2020
@byronyi
Contributor
byronyi commented Feb 11, 2020

I am working on a "TFIndexedDataset" RFC for externally shuffling datasets whose size exceeds available memory. Let me know your preferred use case and I will consider adding it to the RFC. Thanks!

@gowthamkpr gowthamkpr added the type:feature Feature requests label Feb 11, 2020
@gowthamkpr gowthamkpr assigned jsimsa and unassigned gowthamkpr Feb 11, 2020
@gowthamkpr gowthamkpr added the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Feb 11, 2020
@Flamefire
Contributor Author

@byronyi Basically a way to use either TFDS or TFRecord files where shuffling considers the whole dataset. To be more specific: I expect a dataset to know its size and to provide random access to its elements where possible (and I assume it is possible for most datasets, as they consist of images, videos, or other files or lines which can be counted and ordered first; I even thought TFRecord files were made for that). With these two properties a random shuffle operation should exist which produces every element of the dataset exactly once, in a completely random order. So basically a dataset.shuffle(dataset.size) equivalent which doesn't preload everything into memory.
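What this asks for can be sketched in plain Python (the names `full_shuffle`, `get_example`, and `num_examples` are hypothetical stand-ins for whatever random-access API a dataset format would provide): permute the index range once, then fetch elements by index, so every element appears exactly once per epoch while only one element is held in memory at a time:

```python
import random

def full_shuffle(get_example, num_examples, seed=0):
    """Yield every element exactly once, in a fully random order,
    by permuting indices instead of buffering elements."""
    indices = list(range(num_examples))
    random.Random(seed).shuffle(indices)
    for i in indices:
        yield get_example(i)  # random access into the stored dataset

data = list(range(100))
epoch = list(full_shuffle(data.__getitem__, len(data)))
assert sorted(epoch) == data  # an exact permutation, unlike a small shuffle buffer
```

The memory cost is one int per element (the index permutation) rather than one buffered element per slot, which is what makes this feasible for datasets larger than RAM if the storage supports seeking.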

@grofte
grofte commented Nov 5, 2020

Feature request:
Add a return_shuffled argument to the .cache() method on tf.data.Datasets. Whether the dataset is cached in memory or on disk, it should be doable to build a pointer index while caching and sample randomly from that index when reading from the cache. Obviously you can't currently do that from a TFRecord file, and implementing it there would be a lot of work. It also makes sense that data would be stored on a magnetic drive but cached on an SSD.

@Flamefire The recommended approach is to write your dataset out into multiple TFRecord files and then load it like so:

files = tf.data.Dataset.from_tensor_slices(filenames)
files = files.shuffle(len(filenames))
ds = files.interleave(lambda x: tf.data.TFRecordDataset(x, compression_type='GZIP').prefetch(1),
                      num_parallel_calls=tf.data.experimental.AUTOTUNE,
                      deterministic=False, cycle_length=10)

That way you are reading from 10 files at once (still serially within each file), feeding the shuffle buffer at the appropriate point in your pipeline. Each epoch you shuffle the order of the files, and if you have saved the dataset in more than 10 files that changes which files you read from first (i.e. with 20 files, the expectation is that 5 files overlap between the first halves of epochs 1 and 2).
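As a rough mental model of this recipe (pure Python, not the tf.data implementation; `interleaved_order` is a made-up name), shuffling the file list randomizes which shards feed the pipeline while each shard is still read sequentially:

```python
import random

def interleaved_order(files, cycle_length=2, seed=0):
    """Model of shuffle-files + interleave: shuffle the file order,
    then round-robin elements from `cycle_length` open files,
    opening the next file whenever one is exhausted."""
    rng = random.Random(seed)
    files = list(files)
    rng.shuffle(files)                  # like files.shuffle(len(filenames))
    open_files = [iter(f) for f in files[:cycle_length]]
    next_file = cycle_length
    order = []
    while open_files:
        for it in list(open_files):     # one element per open file per round
            x = next(it, None)
            if x is None:
                open_files.remove(it)
                if next_file < len(files):
                    open_files.append(iter(files[next_file]))
                    next_file += 1
            else:
                order.append(x)
    return order

shards = [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
mixed = interleaved_order(shards, cycle_length=2)
assert sorted(mixed) == list(range(9))  # every element appears exactly once
```

Note that within each shard the relative order is preserved, which is why a shuffle buffer after the interleave is still needed.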

@jsimsa I really think adding a shuffle capability to the cache of tf.data.Dataset would be a slam dunk. It would definitely let me remove a lot of code from my projects that exists only to deal with TensorFlow's limited shuffling capability ❤️

@jsimsa
Contributor
jsimsa commented Nov 10, 2020

@grofte am I correct to assume that you would want the functionality of return_shuffled to be equivalent to cache().shuffle(buffer_size=NUM_ELEMENTS) while avoiding the extra buffer that shuffle introduces? In particular, the expected behavior during the first epoch (where the cache is constructed) would be that the first element is returned only after the entire cache has been constructed.

I don't think return_shuffled is general enough (e.g. you would also need to specify the RNG seed management to achieve reproducibility and deterministic reshuffling across epochs). What I can imagine, though, is introducing a tf.data.Dataset.at(i) transformation which could be used to index into "indexable" datasets (such as cache or from_tensor_slices). The way one would use it then would be:

ds = ... # the dataset to be cached
ds = ds.cache()

indices = tf.data.Dataset.range(NUM_ELEMENTS)
indices = indices.shuffle(NUM_ELEMENTS)
ds = indices.flat_map(lambda index: ds.at(index))

This is a more general API, which handles the aspect of shuffling mentioned above and could be used with any dataset that supports indexing. The main disadvantage is that it requires the number of elements of the cached dataset to be known.

@grofte
grofte commented Nov 11, 2020

You are 100% correct that an .at method would be an excellent extension. Probably very useful for researchers who need their data with some kind of deterministic properties. But for both applications you are going to need to read the whole dataset once without training, analogous to how the Keras preprocessing layers have an .adapt method, because you need to link the indices to memory/disk addresses either way. Say you didn't, and your first three indices were 100, 3, and 2000. Then you would have to read the first 100 elements/rows from the file, then 3, and then 1997. So 2100 reads to get three elements/rows. That's very inefficient.

As I understand it, caching currently happens during the first training epoch for efficiency reasons. However, if you expect to train your model for many epochs, performing the caching beforehand adds only a small percentage of time for a large potential gain in model quality (and faster convergence, which reduces training time). That way you wouldn't have to know the number of elements in advance.

@byronyi
Contributor
byronyi commented Nov 11, 2020

> I am working on a "TFIndexedDataset" RFC for externally shuffling datasets whose size exceeds available memory. Let me know your preferred use case and I will consider adding it to the RFC. Thanks!

The proposed IndexedDataset is stored in the very same on-disk format as cached datasets, but it could live in remote storage without having to be read through first. The dataset.at(index) API looks perfect, and if people in this thread are still interested, I could continue working on this proposal.

@jsimsa
Contributor
jsimsa commented Nov 11, 2020

@grofte I am a little bit confused by your example. I would assume that if the first three indices were 100, 3, and 2000, then you would need to read 100 elements to get the first results, and then either 0 (if the read elements are cached) or 3, and then either 1900 (if the read elements are cached) or 2000. So you would either need to read 2000 elements if the read elements are cached or 2103 elements if there is no caching -- neither of which matches your description.
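The two counting models in this exchange can be written down directly (`sequential_reads` is a made-up helper; it uses the convention above where reaching index i costs i reads from the start of the file):

```python
def sequential_reads(indices, cache_reads=False):
    """Count reads needed to serve `indices` from a sequential-only file.
    Without caching, every access restarts from the beginning; with
    caching, cost is driven by the furthest index requested so far."""
    if not cache_reads:
        # Each access re-reads from the start up to its index.
        return sum(indices)
    total, furthest = 0, 0
    for i in indices:
        total += max(0, i - furthest)  # only new elements cost a read
        furthest = max(furthest, i)
    return total

print(sequential_reads([100, 3, 2000]))                    # 2103, no caching
print(sequential_reads([100, 3, 2000], cache_reads=True))  # 2000, with caching
```

This matches the numbers above: 100 + 3 + 2000 = 2103 without caching, and 100 + 0 + 1900 = 2000 when already-read elements are cached.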

Note that at will generally not require reading the whole dataset. For datasets that cannot be efficiently indexed into, it simply won't be supported (i.e. it will throw an informative error). The existing skip transformation could be used to perform the inefficient indexing that you are alluding to.

I was asking what you would expect the behavior of cache(..., return_shuffled=True) to be during the first epoch (i.e. when the data is not cached yet). In order for the method to be able to return any permutation of the input dataset, it needs to a) know the cardinality of the input dataset, so that it knows the range of indices to consider, and b) possibly read the entire input dataset, if the first element to return after shuffling happens to be the last element of its input.

@byronyi it makes sense for us to chat over VC about your proposal -- we have internal WIP proposal for indexed datasets as well and it would be great if we could align the two and I would be more than happy for you to lead the effort. I will reach out to you via email.

@tensorflowbutler tensorflowbutler removed the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Nov 13, 2020

7 participants