
tf.py_function does not output ragged tensors #69777

Open
tirk999 opened this issue Jun 14, 2024 · 2 comments
Labels: comp:ops (OPs related issues) · stat:awaiting tensorflower (awaiting response from tensorflower) · TF 2.15 (issues related to 2.15.x) · type:bug

tirk999 commented Jun 14, 2024

Issue type

Bug

Have you reproduced the bug with TensorFlow Nightly?

No

Source

source

TensorFlow version

2.15

Custom code

Yes

OS platform and distribution

Windows

Mobile device

No response

Python version

3.9

Bazel version

No response

GCC/compiler version

No response

CUDA/cuDNN version

No response

GPU model and memory

No response

Current behavior?

Hello everyone,

I am having a problem with tf.py_function: I want it to output a ragged tensor, but I cannot get it to work.

The broader context is that I am training a neural network, YOLO, for object detection. I want to use some data augmentation techniques that do not support bounding boxes, so I moved to albumentations for data augmentation. The problem is that albumentations works on NumPy arrays. I use tf.py_function to wrap my augmentation function, but I need the bounding boxes to come out as a ragged tensor.

Reference code: https://keras.io/examples/vision/yolov8/

```python
# map_augmentation is my augmentation function.
train_ds = train_data.map(load_dataset, num_parallel_calls=tf.data.AUTOTUNE)
train_ds = train_ds.shuffle(BATCH_SIZE * 4)
train_ds = train_ds.ragged_batch(BATCH_SIZE, drop_remainder=True)
train_ds = train_ds.map(map_augmentation, num_parallel_calls=tf.data.AUTOTUNE)
```

Inside of the map_augmentation function:

```python
augmented_images, augmented_classes, augmented_bboxes = tf.py_function(
    func=apply_augmentation,
    inp=[images[j, :, :, :], classes[j], bbox[j]],
    Tout=[tf.float32, tf.float32, tf.float32],
)
```
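One possible restructuring, sketched under assumptions rather than offered as the fix for this bug: run the NumPy augmentation on individual (unbatched) examples, before `ragged_batch`, so that `tf.py_function` only ever handles dense tensors, and let `ragged_batch` produce the ragged bounding boxes afterwards. `apply_augmentation` below is a hypothetical placeholder (it drops boxes whose x_min is negative) standing in for an albumentations pipeline; the toy generator replaces `load_dataset`, and boxes are assumed to be shaped `[num_boxes, 4]` with 3-channel images.

```python
import numpy as np
import tensorflow as tf


def apply_augmentation(image, classes, boxes):
    # Hypothetical stand-in for an albumentations pipeline: it receives eager
    # tensors, works in NumPy, and may change how many boxes survive.
    image, classes, boxes = image.numpy(), classes.numpy(), boxes.numpy()
    keep = boxes[:, 0] >= 0.0  # placeholder rule: drop boxes with x_min < 0
    return image, classes[keep], boxes[keep]


def map_augmentation(image, classes, boxes):
    # Runs on ONE example at a time, so every input here is a dense tensor.
    image, classes, boxes = tf.py_function(
        func=apply_augmentation,
        inp=[image, classes, boxes],
        Tout=[tf.float32, tf.float32, tf.float32],
    )
    # py_function loses static shapes; restore ranks so ragged_batch can
    # tell which dimensions vary between examples.
    image.set_shape([None, None, 3])
    classes.set_shape([None])
    boxes.set_shape([None, 4])
    return image, classes, boxes


def toy_examples():
    # Hypothetical data standing in for load_dataset: two 8x8 images.
    yield (np.zeros((8, 8, 3), np.float32),
           np.array([0., 1.], np.float32),
           np.array([[0., 0., 4., 4.], [2., 2., 6., 6.]], np.float32))
    yield (np.zeros((8, 8, 3), np.float32),
           np.array([1., 2.], np.float32),
           np.array([[1., 1., 3., 3.], [-1., 0., 2., 2.]], np.float32))


signature = (
    tf.TensorSpec([None, None, 3], tf.float32),
    tf.TensorSpec([None], tf.float32),
    tf.TensorSpec([None, 4], tf.float32),
)
ds = tf.data.Dataset.from_generator(toy_examples, output_signature=signature)
ds = ds.map(map_augmentation, num_parallel_calls=tf.data.AUTOTUNE)  # augment first
ds = ds.ragged_batch(2, drop_remainder=True)  # then batch into ragged tensors

images, classes, boxes = next(iter(ds))
print(boxes.row_lengths())  # boxes per image after augmentation
```

The design choice here is to keep the ragged structure out of `tf.py_function` entirely: `ragged_batch` encodes the variable number of boxes per image, so the `Tout` list stays plain dtypes.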

I can change Tout to a tf.RaggedTensorSpec, but I always get errors.
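If the ragged structure has to cross the `tf.py_function` boundary, a version-independent workaround is to pass and return only dense tensors: split the ragged tensor into `flat_values` and `row_lengths()` on the way in, then rebuild it with `tf.RaggedTensor.from_row_lengths` on the way out, so `Tout` never needs a `tf.RaggedTensorSpec`. This is a sketch under assumptions, not TensorFlow's intended mechanism: `numpy_augment` and `ragged_py_function` are hypothetical names, the transform (doubling every value) is a placeholder, and the input is assumed to be a 2-D ragged tensor with `ragged_rank=1`.

```python
import numpy as np
import tensorflow as tf


def numpy_augment(flat_values, row_lengths):
    # Hypothetical placeholder for NumPy-side work: receives eager tensors and
    # returns the (possibly modified) values plus the length of each row.
    values = flat_values.numpy() * 2.0
    lengths = row_lengths.numpy()
    return values.astype(np.float32), lengths.astype(np.int64)


@tf.function
def ragged_py_function(rt):
    # Only dense tensors cross the py_function boundary.
    values, lengths = tf.py_function(
        func=numpy_augment,
        inp=[rt.flat_values, rt.row_lengths()],
        Tout=[tf.float32, tf.int64],
    )
    # py_function output has unknown shape; both pieces are 1-D here.
    values.set_shape([None])
    lengths.set_shape([None])
    # Rebuild the ragged tensor outside py_function.
    return tf.RaggedTensor.from_row_lengths(values, lengths)


result = ragged_py_function(tf.ragged.constant([[1., 2., 3.], [4., 5.]]))
print(result)
```

The same split should carry over to bounding boxes shaped `[batch, None, 4]`; there the flat values are 2-D, so the first `set_shape` call would become `[None, 4]`.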

I also tried the solution discussed in #26453, but without success.
I did manage to get this to run:

```python
import tensorflow

tf = tensorflow.compat.v2

def py_func():
    return tf.ragged.constant([[1., 2., 3.], [4., 5.]])

print(tf.py_function(py_func, [], Tout=tf.RaggedTensorSpec([2, None], tf.float32)))
# <tf.RaggedTensor [[1.0, 2.0, 3.0], [4.0, 5.0]]>
```

but the problem appears as soon as py_func takes an input.

Standalone code to reproduce the issue

```python
import numpy as np
import tensorflow as tf

elements = [np.array([1., 2., 3.]).astype('float32'), np.array([4., 5.]).astype('float32')]

def py_func(elements):
    return tf.ragged.constant([np.array([1., 2., 3.]).astype('float32'), np.array([4., 5.]).astype('float32')])

print(tf.py_function(py_func, [elements], Tout=tf.RaggedTensorSpec([2, None], tf.float32)))
```
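The `ValueError` in the log appears to be raised while `tf.py_function` converts `inp` to tensors: `elements` is a Python list of arrays with different lengths, which cannot become a dense tensor, so the failure seems to happen before `py_func` ever runs. A minimal sketch, assuming a TensorFlow version whose `tf.py_function` accepts composite tensors such as `tf.RaggedTensor` in `inp` (the zero-input example above shows the `Tout` side works in 2.15), is to build the ragged input first. The doubling inside `py_func` is only a placeholder.

```python
import numpy as np
import tensorflow as tf

elements = [np.array([1., 2., 3.], np.float32), np.array([4., 5.], np.float32)]


def py_func(rt):
    # rt arrives as an eager tf.RaggedTensor; the ragged result must match
    # the Tout spec passed below.
    return rt * 2.0


# Convert the unequal-length rows into a RaggedTensor *before* calling
# py_function, instead of letting inp try to build a dense tensor.
ragged_input = tf.ragged.constant([e.tolist() for e in elements])
result = tf.py_function(
    py_func, [ragged_input], Tout=tf.RaggedTensorSpec([2, None], tf.float32)
)
print(result)
```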

Relevant log output

```shell
ValueError: Can't convert non-rectangular Python sequence to Tensor.
```
@google-ml-butler google-ml-butler bot added the type:bug Bug label Jun 14, 2024
@Venkat6871 Venkat6871 added the TF 2.15 For issues related to 2.15.x label Jun 17, 2024
Venkat6871 commented

Hi @tirk999,

I tried running your code on Colab with TF v2.16.1 and TF nightly and faced the same issue. Please find the gist here for reference.

Thank you!

tirk999 commented Jun 17, 2024

Hello @Venkat6871, do you think this issue will be fixed?
Thank you!

@Venkat6871 Venkat6871 added comp:ops OPs related issues stat:awaiting tensorflower Status - Awaiting response from tensorflower labels Jun 18, 2024