Exposes num_parallel_reads and num_parallel_calls #1232
@@ -16,11 +16,28 @@
import tensorflow as tf
from tensorflow_io.core.python.ops import core_ops
from typing import Optional

_DEFAULT_READER_BUFFER_SIZE_BYTES = 256 * 1024  # 256 KB
_DEFAULT_READER_SCHEMA = ""
# From https://github.com/tensorflow/tensorflow/blob/v2.0.0/tensorflow/python/data/ops/readers.py


def _require(condition: bool, err_msg: Optional[str] = None) -> None:
    """Checks that the specified condition is true, and raises an exception otherwise.

    :param condition: The condition to test

    :param err_msg: If specified, the error message to use when the condition is not true.

    :raises ValueError: Raised when the condition is false

    :return: None.
    """
    if not condition:
        raise ValueError(err_msg)

Review comment: Use consistent docstring style
Reply: Fixed, please review.
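As a standalone illustration (not part of the diff, assuming `_require` from this module is in scope), the helper simply turns a failed check into a `ValueError`:

```
_require(8 > 0, "num_parallel_calls: 8 must be greater than 0")  # passes, returns None

try:
    _require(0 > 0, "num_parallel_calls: 0 must be greater than 0")
except ValueError as e:
    print(e)  # num_parallel_calls: 0 must be greater than 0
```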

# copied from https://github.com/tensorflow/tensorflow/blob/
# 3095681b8649d9a828afb0a14538ace7a998504d/tensorflow/python/data/ops/readers.py#L36
def _create_or_validate_filenames_dataset(filenames):
@@ -52,21 +69,62 @@ def _create_or_validate_filenames_dataset(filenames):

# copied from https://github.com/tensorflow/tensorflow/blob/
# 3095681b8649d9a828afb0a14538ace7a998504d/tensorflow/python/data/ops/readers.py#L67
def _create_dataset_reader(dataset_creator, filenames, num_parallel_reads=None):
    """create_dataset_reader"""

    def read_one_file(filename):
        filename = tf.convert_to_tensor(filename, tf.string, name="filename")
        return dataset_creator(filename)

    if num_parallel_reads is None:
        return filenames.flat_map(read_one_file)
    if num_parallel_reads == tf.data.experimental.AUTOTUNE:
        return filenames.interleave(
            read_one_file, num_parallel_calls=num_parallel_reads
        )
def _create_dataset_reader(
    dataset_creator,
    filenames,
    cycle_length=None,
    num_parallel_calls=None,
    deterministic=None,
    block_length=1,
):
    """
    This creates a dataset reader which reads records from multiple files and interleaves them together. For example:
    ```
    dataset = Dataset.range(1, 6)  # ==> [ 1, 2, 3, 4, 5 ]
    # NOTE: New lines indicate "block" boundaries.
    dataset = dataset.interleave(
        lambda x: Dataset.from_tensors(x).repeat(6),
        cycle_length=2, block_length=4)
    list(dataset.as_numpy_iterator())
    ```
    Results in the following output:
    [1,1,1,1,
     2,2,2,2,
     1,1,
     2,2,
     3,3,3,3,
     4,4,4,4,
     3,3,
     4,4,
     5,5,5,5,
     5,5,
    ]
    Args:
        dataset_creator: Initializer for _AvroRecordDataset
        filenames: A `tf.data.Dataset` iterator of filenames to read
        cycle_length: The number of files to be processed in parallel. This is used by
            `Dataset.interleave`. We set this equal to `block_length`, so that each cycle
            returns n records from each of the n files being read.
        num_parallel_calls: Number of threads spawned by the interleave call.
        deterministic: Sets whether the interleaved records are produced in deterministic order.
            In `tf.data.Dataset.interleave` this defaults to True.
        block_length: The number of consecutive records to output from each file before
            cycling to the next file. Defaults to 1.
    Returns:
        A dataset iterator with an interleaved list of parsed avro records.
    """

    def read_many_files(filenames):
        filenames = tf.convert_to_tensor(filenames, tf.string, name="filename")
        return dataset_creator(filenames)

    if cycle_length is None:
        return filenames.flat_map(read_many_files)

    return filenames.interleave(
        read_one_file, cycle_length=num_parallel_reads, block_length=1
        read_many_files,
        cycle_length=cycle_length,
        num_parallel_calls=num_parallel_calls,
        block_length=block_length,
        deterministic=deterministic,
    )
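The interleave behaviour described in the docstring above can be reproduced with plain `tf.data` ops. A minimal standalone sketch (not part of the diff):

```
import tensorflow as tf

# With cycle_length=2 and block_length=4, blocks of 4 elements are drawn from
# two inputs at a time before cycling back for their remaining elements.
dataset = tf.data.Dataset.range(1, 6)  # ==> [1, 2, 3, 4, 5]
dataset = dataset.interleave(
    lambda x: tf.data.Dataset.from_tensors(x).repeat(6),
    cycle_length=2,
    block_length=4,
)
print(list(dataset.as_numpy_iterator()))
# [1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 3, 3, 4, 4, 5, 5, 5, 5, 5, 5]
```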
@@ -128,7 +186,14 @@ class AvroRecordDataset(tf.data.Dataset):
    """A `Dataset` comprising records from one or more AvroRecord files."""

    def __init__(
        self, filenames, buffer_size=None, num_parallel_reads=None, reader_schema=None
        self,
        filenames,
        buffer_size=None,
        num_parallel_reads=None,
        num_parallel_calls=None,
        reader_schema=None,
        deterministic=True,
        block_length=1,
    ):
        """Creates a `AvroRecordDataset` to read one or more AvroRecord files.

@@ -144,25 +209,65 @@ def __init__(
            files read in parallel are outputted in an interleaved order. If your
            input pipeline is I/O bottlenecked, consider setting this parameter to a
            value greater than one to parallelize the I/O. If `None`, files will be
            read sequentially.
            read sequentially. This must be set equal to or greater than `num_parallel_calls`.
            This constraint exists because `num_parallel_reads` becomes `cycle_length` in the
            underlying call to `tf.data.Dataset.interleave`, and `cycle_length` is required to be
            equal to or higher than the number of threads (`num_parallel_calls`).
            `cycle_length` in `tf.data.Dataset.interleave` dictates how many input elements it picks up to process at a time.
          num_parallel_calls: (Optional.) Number of threads to spawn. This must be set to `None`
            or a value greater than 0, and must be less than or equal to `num_parallel_reads`.
            This defines the degree of parallelism in the underlying Dataset.interleave call.
          reader_schema: (Optional.) A `tf.string` scalar representing the reader
            schema or None.
<<<<<<< HEAD

=======
          deterministic: (Optional.) A boolean controlling whether determinism should be traded for performance by
            allowing elements to be produced out of order. Defaults to `True`.
          block_length: The number of consecutive records to output from each file before cycling to the next. Defaults to 1.
>>>>>>> d41d946... Added parameter constraints

Review comment: Conflicts here
Reply: Fixed, please review.

        Raises:
          TypeError: If any argument does not have the expected type.
          ValueError: If any argument does not have the expected shape.
        """
        _require(
            num_parallel_calls is None
            or num_parallel_calls == tf.data.experimental.AUTOTUNE
            or num_parallel_calls > 0,
            f"num_parallel_calls: {num_parallel_calls} must be set to None, "
            f"tf.data.experimental.AUTOTUNE, or greater than 0",
        )
        if num_parallel_calls is not None:
            _require(
                num_parallel_reads is not None
                and (
                    num_parallel_reads >= num_parallel_calls
                    or num_parallel_reads == tf.data.experimental.AUTOTUNE
                ),
                f"num_parallel_reads: {num_parallel_reads} must be greater than or equal to "
                f"num_parallel_calls: {num_parallel_calls} or set to tf.data.experimental.AUTOTUNE",
            )

        filenames = _create_or_validate_filenames_dataset(filenames)

        self._filenames = filenames
        self._buffer_size = buffer_size
        self._num_parallel_reads = num_parallel_reads
        self._num_parallel_calls = num_parallel_calls
        self._reader_schema = reader_schema
        self._block_length = block_length

        def creator_fn(filename):
            return _AvroRecordDataset(filename, buffer_size, reader_schema)
        def read_multiple_files(filenames):
            return _AvroRecordDataset(filenames, buffer_size, reader_schema)

        self._impl = _create_dataset_reader(creator_fn, filenames, num_parallel_reads)
        self._impl = _create_dataset_reader(
            read_multiple_files,
            filenames,
            cycle_length=num_parallel_reads,
            num_parallel_calls=num_parallel_calls,
            deterministic=deterministic,
            block_length=block_length,
        )
        variant_tensor = self._impl._variant_tensor  # pylint: disable=protected-access
        super().__init__(variant_tensor)
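For context, a hedged usage sketch of the new constructor arguments (not part of the diff; it assumes `AvroRecordDataset` from this module is in scope, and the file names are placeholders):

```
filenames = ["part-0.avro", "part-1.avro", "part-2.avro"]  # placeholder file names

# num_parallel_reads becomes cycle_length of the underlying interleave, so it
# must be >= num_parallel_calls (or AUTOTUNE); otherwise _require raises ValueError.
dataset = AvroRecordDataset(
    filenames,
    buffer_size=256 * 1024,
    num_parallel_reads=4,   # number of files opened and interleaved at once
    num_parallel_calls=2,   # threads used by the underlying Dataset.interleave
    reader_schema=None,     # or a tf.string scalar holding the reader schema
    deterministic=True,
    block_length=1,
)
```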
@@ -171,13 +276,17 @@ def _clone(
        filenames=None,
        buffer_size=None,
        num_parallel_reads=None,
        num_parallel_calls=None,
        reader_schema=None,
        block_length=None,
    ):
        return AvroRecordDataset(
            filenames or self._filenames,
            buffer_size or self._buffer_size,
            num_parallel_reads or self._num_parallel_reads,
            num_parallel_calls or self._num_parallel_calls,
            reader_schema or self._reader_schema,
            block_length or self._block_length,
        )

    def _inputs(self):
The second file in the diff:
@@ -37,7 +37,6 @@ def make_avro_record_dataset(
    shuffle_seed=None,
    prefetch_buffer_size=tf.data.experimental.AUTOTUNE,
    num_parallel_reads=None,
    num_parallel_parser_calls=None,
    drop_final_batch=False,
):
    """Reads and (optionally) parses avro files into a dataset.

@@ -79,14 +78,26 @@ def make_avro_record_dataset(
      prefetch_buffer_size: (Optional.) An int specifying the number of
        feature batches to prefetch for performance improvement.
        Defaults to auto-tune. Set to 0 to disable prefetching.
<<<<<<< HEAD
<<<<<<< HEAD

Review comment: @ashahab Can you resolve the merge conflict here?
Reply: @yongtang absolutely, sorry about that. @StanfordMCP
Reply: @yongtang Fixed the issue, please take a look.

      num_parallel_reads: (Optional.) Number of threads used to read
        records from files. By default or if set to a value >1, the
        results will be interleaved.

      num_parallel_parser_calls: (Optional.) Number of parallel
=======
      num_parallel_calls: (Optional.) Number of threads used to read
        records from files. By default or if set to a value >1, the
        results will be interleaved.
      num_parallel_reads: (Optional.) Number of parallel
>>>>>>> f7032e3... Exposed num_parallel_reads as well as num_parallel_calls

Review comment: Also here
Reply: Fixed

        records to parse in parallel. Defaults to an automatic selection.
=======
      num_parallel_reads: (Optional.) Number of parallel
        records to parse in parallel. Defaults to None (no parallelization).
>>>>>>> d41d946... Added parameter constraints

Review comment: And here.
Reply: Fixed

      drop_final_batch: (Optional.) Whether the last batch should be
        dropped in case its size is smaller than `batch_size`; the
        default behavior is not to drop the smaller batch.
@@ -99,20 +110,15 @@ def make_avro_record_dataset(
    """
    files = tf.data.Dataset.list_files(file_pattern, shuffle=shuffle, seed=shuffle_seed)

    if num_parallel_reads is None:
        # Note: We considered auto-tuning this value, but there is a concern
        # that this affects the mixing of records from different files, which
        # could affect training convergence/accuracy, so we are defaulting to
        # a constant for now.
        num_parallel_reads = 24

    if reader_buffer_size is None:
        reader_buffer_size = 1024 * 1024

    num_parallel_calls = num_parallel_reads
    dataset = AvroRecordDataset(
        files,
        buffer_size=reader_buffer_size,
        num_parallel_reads=num_parallel_reads,
        num_parallel_calls=num_parallel_calls,
        block_length=num_parallel_calls,
        reader_schema=reader_schema,
    )
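To make the defaulting above concrete, when the caller passes no parallelism options the code effectively ends up with the following call (a sketch, not from the diff):

```
# num_parallel_reads defaults to the constant 24 (not autotuned), and
# num_parallel_calls and block_length are both derived from it.
dataset = AvroRecordDataset(
    files,
    buffer_size=1024 * 1024,   # reader_buffer_size default
    num_parallel_reads=24,
    num_parallel_calls=24,
    block_length=24,
    reader_schema=reader_schema,
)
```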
@@ -131,14 +137,11 @@ def make_avro_record_dataset(

    dataset = dataset.batch(batch_size, drop_remainder=drop_final_batch)

    if num_parallel_parser_calls is None:
        num_parallel_parser_calls = tf.data.experimental.AUTOTUNE

    dataset = dataset.map(
        lambda data: parse_avro(
            serialized=data, reader_schema=reader_schema, features=features
        ),
        num_parallel_calls=num_parallel_parser_calls,
        num_parallel_calls=num_parallel_calls,
    )

    if prefetch_buffer_size == 0:
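Finally, a hedged end-to-end sketch of calling the updated `make_avro_record_dataset` (not part of the diff; the file pattern, reader schema, and feature spec below are illustrative assumptions):

```
import tensorflow as tf

# Assumed Avro reader schema and feature spec, for illustration only.
reader_schema = """
{"type": "record", "name": "row",
 "fields": [{"name": "label", "type": "long"}]}
"""
features = {"label": tf.io.FixedLenFeature([], tf.int64)}

dataset = make_avro_record_dataset(
    file_pattern="data/part-*.avro",
    features=features,
    batch_size=32,
    reader_schema=reader_schema,
    num_parallel_reads=8,  # also reused internally as num_parallel_calls and block_length
    shuffle=True,
)
```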
Review comment: Since we are not using any type checkers (like mypy) as of now, I feel this style is a bit out of place when compared with other modules in the codebase.
Reply: @kvignesh1420 Thanks for the comment. Updated, please review.
Reply: Thanks!