

Missing GPU op for zeros_like for RaggedTensorVariant, error occurs when Ragged Tensor fed thru tf.map_fn #46635

Open
djoshea opened this issue Jan 24, 2021 · 18 comments
Assignees
Labels
comp:gpu GPU related issues comp:keras Keras related issues TF 2.4 for issues related to TF 2.4 TF 2.16 type:bug Bug

Comments

@djoshea
djoshea commented Jan 24, 2021

System information

Have I written custom code (as opposed to using a stock example script provided in TensorFlow): Yes, included below
OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Ubuntu 20.10
Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device: n/a
TensorFlow installed from (source or binary): pip binary
TensorFlow version (use command below): v1.12.1-49539-g18d8bcbe72b 2.5.0-dev20210123
Python version: '3.8.6 | packaged by conda-forge | (default, Nov 27 2020, 19:31:52) \n[GCC 9.3.0]'
Bazel version (if compiling from source): n/a
GCC/Compiler version (if compiling from source): n/a
CUDA/cuDNN version: 11.0 / 8
GPU model and memory: TITAN X (Pascal) computeCapability: 6.1

Describe the current behavior

I have a keras layer RescaleB that accepts a ragged tensor with shape [batch, (time), in_dim]. The layer calls map_fn to process each example in the batch separately, scaling the values along the inner dimension by a trainable gain vector. (The details of the operation aren't critical, but the ragged tensor going into map_fn is.)
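A minimal sketch of what such a layer might look like (the class and variable names here are my reconstruction from the description above, not the actual code from the Colab notebook):

```python
import tensorflow as tf

# Hypothetical reconstruction of the layer described above. Each batch
# element of the ragged input is a dense [time, in_dim] slice; tf.map_fn
# processes the slices one at a time and restacks them into a ragged result.
class RescaleB(tf.keras.layers.Layer):
    def __init__(self, in_dim):
        super().__init__()
        self.gain = self.add_weight(
            name="gain", shape=[in_dim], initializer="ones", trainable=True)

    def call(self, inputs):  # inputs: RaggedTensor [batch, (time), in_dim]
        in_dim = inputs.shape[-1]
        return tf.map_fn(
            lambda x: x * self.gain,  # x is a dense [time, in_dim] tensor
            inputs,
            # Each per-element output is a dense tensor with a variable
            # first dimension, hence ragged_rank=0 in the output spec.
            fn_output_signature=tf.RaggedTensorSpec(
                shape=[None, in_dim], dtype=inputs.dtype, ragged_rank=0))
```

On CPU this runs forward and backward; it is the gradient of the map_fn-internal TensorListSetItem that trips the missing GPU zeros_like kernel.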

Using this layer fails with "No unary variant unary_op function found for unary variant op enum: 1 Variant type_name: RaggedTensorVariant for device type: GPU" on a node whose name ends with rescale_b/map/while/TensorArrayV2Write/TensorListSetItem_grad/zeros_like, which suggests that the zeros_like operation isn't defined for ragged tensors on the GPU.

In this simplified example, I also include RescaleA, which accomplishes the same task using tf.ragged.map_flat_values, although in my real use case I need map_fn.
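For reference, the map_flat_values approach can be sketched like this (the gain values and shapes are made up for illustration, not taken from the notebook):

```python
import tensorflow as tf

# Because the scaling only touches the innermost (uniform) dimension, it can
# be applied directly to the flat values of the ragged tensor, which
# sidesteps tf.map_fn entirely and runs fine on GPU.
gain = tf.constant([1., 2., 3.])
rt = tf.ragged.constant(
    [[[1., 1., 1.]],
     [[2., 2., 2.], [3., 3., 3.]]],
    ragged_rank=1)  # shape [2, (time), 3]
scaled = tf.ragged.map_flat_values(lambda v: v * gain, rt)
# scaled.to_list() == [[[1.0, 2.0, 3.0]], [[2.0, 4.0, 6.0], [3.0, 6.0, 9.0]]]
```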

Describe the expected behavior

I'd expect RescaleB and RescaleA to function identically.

Standalone code to reproduce the issue

https://colab.research.google.com/drive/1mHycCXJL94VuCGkXIJ0bIXtbYamyZo78

I've reproduced the issue locally with tf-nightly-gpu TF 2.5, but I can't seem to get the nightly version to see the GPU on Colab. The Colab notebook is using TF 2.4, but the issue remains in TF 2.5 nightly.

Other info / logs

This may be the same issue as #44231 but hopefully the additional detail here is helpful.

@djoshea djoshea added the type:bug Bug label Jan 24, 2021
@djoshea djoshea changed the title Missing GPU op for zeros_like for Ragged Tensor, error occurs when Ragged Tensor fed thru tf.map_fn Missing GPU op for zeros_like for RaggedTensorVariant, error occurs when Ragged Tensor fed thru tf.map_fn Jan 24, 2021
@Saduf2019 Saduf2019 added TF 1.12 Issues related to TF 1.12 comp:keras Keras related issues comp:gpu GPU related issues labels Jan 25, 2021
@Saduf2019
Contributor

I ran the code on TF 2.4 and faced a different error on nightly; please find the gist here.

@Saduf2019 Saduf2019 added TF 2.4 for issues related to TF 2.4 and removed TF 1.12 Issues related to TF 1.12 labels Jan 25, 2021
@Saduf2019 Saduf2019 assigned rmothukuru and unassigned Saduf2019 Jan 25, 2021
@djoshea
Author
djoshea commented Jan 25, 2021

As I mentioned, I don't think tf-nightly-gpu is running on the GPU in Colab. The issue is only present on GPU. In the gist you sent, the cell where it searches for the GPU returns:

Tensorflow version ==  2.5.0-dev20210124

---------------------------------------------------------------------------
SystemError                               Traceback (most recent call last)
<ipython-input-8-6f54130e32b7> in <module>()
      4 print(device_name)
      5 if device_name != '/device:GPU:0':
----> 6   raise SystemError('GPU device not found')
      7 print('Found GPU at: {}'.format(device_name))

SystemError: GPU device not found

Locally, where it is running on the GPU, TF nightly fails with the same error as TF 2.4, so this issue is still present in nightly, at least with my local pip install of tf-nightly-gpu.

@rmothukuru rmothukuru assigned sanjoy and unassigned rmothukuru Jan 28, 2021
@rmothukuru rmothukuru added the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Jan 28, 2021
@sanjoy
Contributor
sanjoy commented Feb 16, 2021

@djoshea this looks like a reasonable request, will you be able to send a PR for this?

@djoshea
Author
djoshea commented Feb 16, 2021

I could potentially try developing it, but I unfortunately don't know where to start. Is there a guide somewhere to implementing new ops for the GPU?

@tensorflowbutler tensorflowbutler removed the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Feb 18, 2021
@sanjoy
Contributor
sanjoy commented Feb 26, 2021

I could potentially try developing it, but I unfortunately don't know where to start. Is there a guide somewhere to implementing new ops for the GPU?

I don't think we have a detailed guide for this, but if you're comfortable with C++ you could try to start at the check that's failing.

@patchmeifyoucan

It would be nice if this could be implemented soon. Using map_fn with RaggedTensors is very convenient when working with data of different shapes. Unfortunately, one must run these operations on the CPU right now.
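A sketch of that CPU workaround, assuming a simple per-row reduction (the example values are mine, not from this thread):

```python
import tensorflow as tf

# Pinning the map_fn to the CPU keeps the RaggedTensorVariant off the GPU,
# where the zeros_like kernel discussed in this issue is missing.
rt = tf.ragged.constant([[1., 2., 3.], [4., 5.]])
with tf.device('/CPU:0'):
    sums = tf.map_fn(tf.reduce_sum, rt, fn_output_signature=tf.float32)
# sums.numpy().tolist() == [6.0, 9.0]
```

The cost, as noted later in this thread, is that the surrounding computation must shuttle data to the CPU for this step.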

@areiner222

I'm relying on map_fn to extract ragged image patches using bounding boxes so I can then produce ragged bounding-box centered feature maps via convolutions. I am encountering the same bug with the gradient step inside a tf.function (I can successfully run outside tf.function). The ragged image patches method is enabling a large speed up with a reduced memory footprint so it would be very useful to enable this in graph mode on GPUs.

@patchmeifyoucan

For me, RaggedTensors + map_fn are quite an enabler because I am implementing a kind of link prediction in a graph neural network. Having TensorFlow handle the individual samples allows for pretty safe programming: I do not have to put all samples in a batch into one large, disconnected graph and make sure that no nodes of different graphs get connected. Using TensorFlow this way results in nice code that is fast to write, in contrast to mask-based solutions, which may need to be changed every time the computation changes. Having this on GPU would be awesome. 💯

@OliverGuy

I would love to see this fixed, especially as iterating over RaggedTensors does not seem to work properly in a lot of cases.

@sushreebarsa
Contributor

I ran the code in TF v2.5 and faced the error; please find the gist here. Thanks!

@llan-ml
llan-ml commented Jun 22, 2021

Hi @sanjoy any updates on this? This is a really needed feature.

@timolange

Hi, I would also very much appreciate it if this bug got fixed.
I think it's somehow related to "No gradient defined for operation RaggedTensorFromVariant" (or no gradients at all),
which was fixed with commit be6b1fd by @edloper.

The part with

    REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(
        ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU, RaggedTensorVariant,
        RaggedTensorVariantZerosLike<CPUDevice>);

just registers the zeros_like op for the CPU device but not for the GPU.

Maybe this will be a quick win to fix.
@edloper, it would be very kind if you could have a look at this.
Thanks.

@JXRiver
Contributor
JXRiver commented Sep 2, 2021

Hi, thanks for raising the issue. According to @edloper "Basically, RaggedTensorVariant objects should never be copied to GPU, because we can't do anything useful with them there. But Placer isn't currently smart enough to figure that out (it just sees a Variant tensor, and doesn't know what kind of value it contains)." We have a project going on right now that hopefully will fix the issue.

@sushreebarsa sushreebarsa added stat:awaiting response Status - Awaiting response from author and removed stat:awaiting response Status - Awaiting response from author labels Sep 3, 2021
@misha-antonenko
misha-antonenko commented May 30, 2022

Hello. Could you please let us know whether there has been any progress on the issue?

For me, it resulted in a sudden sixfold increase in the train step duration due to having to move to the CPU (compared to resorting to a contrived solution with non-ragged tensors).

@ablanch5
ablanch5 commented Dec 8, 2022

Also needing this resolved

@isarandi
isarandi commented Oct 13, 2023

I faced the same issue and resolved it by rolling my own map_fn via TensorArray. Here's a rough sketch:

import tensorflow as tf

def map_fn2(fn, elems, fn_output_signature):
    # Number of elements along the outer (batch) dimension.
    batch_size = tf.shape(tf.nest.flatten(elems)[0])[0]
    arr = tf.TensorArray(
        fn_output_signature.dtype, size=batch_size,
        element_shape=fn_output_signature.shape)
    # Apply fn to each batch element; inside a tf.function, AutoGraph
    # turns this loop into a graph while-loop.
    for i in tf.range(batch_size):
        arr = arr.write(i, fn(tf.nest.map_structure(lambda x: x[i], elems)))
    return arr.stack()

@ivkhar
ivkhar commented Jan 17, 2024

As of TF v2.15, this is still not resolved.
It would be nice if the limitation on using RaggedTensors with map_fn on the GPU were mentioned in the map_fn documentation.

@Ryandry1st
Ryandry1st commented Feb 2, 2024

Agreed, I would like to see this resolved as it is pretty important functionality for using map_fn and ragged tensors. @sanjoy @djoshea
