
Allow py_function to support functions that return RaggedTensor #26453

Closed
woodshop opened this issue Mar 7, 2019 · 4 comments
Assignees
Labels
comp:ops OPs related issues type:feature Feature requests

Comments

@woodshop
woodshop commented Mar 7, 2019

System information

  • TensorFlow version (you are using): 1.13.1
  • Are you willing to contribute it (Yes/No): No

Describe the feature and the current behavior/state.
py_function only supports functions that return Tensors. However, the wrapped function is executed in eager mode and should therefore ideally support other return types consistent with eager mode. As of tf-1.13.1, py_function attempts to convert the returned objects to Tensors, but a RaggedTensor cannot be converted directly to a Tensor, so any attempt to return one raises an exception during the attempted conversion. E.g.,

import tensorflow as tf  # tf-1.13.1

with tf.Graph().as_default():
    elements = [[1., 2., 3.], [4., 5.]]
    ragged1 = tf.ragged.constant(elements)
    def py_func():
        return tf.ragged.constant(elements)
    
    ragged2 = tf.py_function(
        py_func, [], tf.dtypes.float32
    )
    with tf.Session() as sess:
        print(sess.run(ragged1))
        print(sess.run(ragged2))

>>> <tf.RaggedTensorValue [[1.0, 2.0, 3.0], [4.0, 5.0]]>
>>> ...
>>> ValueError: TypeError: object of type 'RaggedTensor' has no len()

I propose that py_function detect which output arguments, if any, are RaggedTensors and return them without attempting to convert them to Tensors. If the proposal is rejected, I suggest that the documentation be updated to make it clearer (either in the API docs or the guides) that a RaggedTensor is not a suitable return type for functions wrapped by py_function.

Will this change the current API? How? No.

Who will benefit from this feature? Anyone who uses RaggedTensor in conjunction with py_function.

Any other info.
As a workaround, one can construct a RaggedTensor from the output of py_function. E.g.,

with tf.Graph().as_default():
    elements = [[1., 2., 3.], [4., 5.]]
    ragged1 = tf.ragged.constant(elements)
    def py_func():
        # Flatten all rows into one list and record each row's length,
        # so the RaggedTensor can be rebuilt from these two dense outputs.
        lengths = [len(element) for element in elements]
        return sum(elements, []), lengths
    
    concatenated, lengths = tf.py_function(
        py_func, [], [tf.dtypes.float32, tf.dtypes.int64]
    )
    ragged2 = tf.RaggedTensor.from_row_lengths(concatenated, lengths)
    with tf.Session() as sess:
        print(sess.run(ragged1))
        print(sess.run(ragged2))

>>> <tf.RaggedTensorValue [[1.0, 2.0, 3.0], [4.0, 5.0]]>
>>> <tf.RaggedTensorValue [[1.0, 2.0, 3.0], [4.0, 5.0]]>
@ymodak ymodak added comp:ops OPs related issues type:feature Feature requests labels Mar 7, 2019
@ymodak ymodak added the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Mar 7, 2019
@edloper
Contributor
edloper commented Feb 18, 2021

In #27679 (comment), I showed how py_function could be extended to handle composite tensor inputs and outputs (and other nested structures, like dicts, tuples, lists, etc). If you have bandwidth to work on a PR that adds that to TensorFlow (with tests etc.), then it would be very welcome; otherwise, you could just use the new_py_function that I defined there, which wraps tf.py_function.

@tensorflowbutler tensorflowbutler removed the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Feb 21, 2021
@JXRiver
Contributor
JXRiver commented Nov 29, 2021

The issue should have been fixed by @edloper. The following example should work.

import tensorflow

# TF1
tf = tensorflow.compat.v1
with tf.Graph().as_default():
    elements = [[1., 2., 3.], [4., 5.]]
    ragged1 = tf.ragged.constant(elements)
    def py_func():
        return tf.ragged.constant(elements)
    
    ragged2 = tf.py_function(
        py_func, [], Tout=tf.RaggedTensorSpec([2, None], tf.float32)
    )
    with tf.Session() as sess:
        print(sess.run(ragged1))  # <tf.RaggedTensorValue [[1.0, 2.0, 3.0], [4.0, 5.0]]>
        print(sess.run(ragged2))  # <tf.RaggedTensorValue [[1.0, 2.0, 3.0], [4.0, 5.0]]>

# TF2
tf = tensorflow.compat.v2
def py_func():
  return tf.ragged.constant([[1., 2., 3.], [4., 5.]])

print(tf.py_function(py_func, [], Tout=tf.RaggedTensorSpec([2, None], tf.float32)))  # <tf.RaggedTensor [[1.0, 2.0, 3.0], [4.0, 5.0]]>

Note that the main difference between this example and the one in the original issue is the Tout argument of tf.py_function. When tf.py_function returns a CompositeTensor, the Tout argument should be an instance of a tf.TypeSpec subclass, i.e., tf.RaggedTensorSpec for a RaggedTensor.

See https://www.tensorflow.org/api_docs/python/tf/py_function for more information on tf.py_function and https://www.tensorflow.org/api_docs/python/tf/TypeSpec for more information on tf.TypeSpec.
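This pattern is perhaps most useful inside input pipelines, where each element produces a variable number of values. A minimal sketch (assumes TF >= 2.7; the `repeat` helper and shapes are illustrative, not from the original thread):

```python
import tensorflow as tf

def repeat(x):
    # Runs eagerly inside py_function: return x copies of the value x,
    # so different elements yield rows of different lengths.
    n = int(x)
    return tf.ragged.constant([[float(x)] * n])  # shape [1, None]

ds = tf.data.Dataset.from_tensor_slices([1, 2, 3])
# Tout is a RaggedTensorSpec, so map() yields RaggedTensors.
ds = ds.map(lambda x: tf.py_function(
    repeat, [x], Tout=tf.RaggedTensorSpec([1, None], tf.float32)))
for rt in ds:
    print(rt)
```

Here each mapped element is a RaggedTensor whose single row has length 1, 2, and 3 respectively.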

@JXRiver JXRiver closed this as completed Nov 29, 2021
@HuangChiEn

> The issue should have been fixed by @edloper. The following example should work. […]

Thanks for your brief tutorial, but the given code does not work if I convert it into a lambda function.

[screenshots of the lambda variant and its traceback]

Looking forward to further improvements in TF.

@edloper
Contributor
edloper commented Dec 6, 2021

@HuangChiEn Are you using TensorFlow 2.7? Support for using composite tensors (such as RaggedTensor) with py_function was added with 2.7, so if you're using an earlier version of TensorFlow, then it won't work. I tried executing your code as written:

tmp = lambda _: tf.ragged.constant([[1., 2., 3.], [4., 5.]])
tf.py_function(tmp, [], Tout=tf.RaggedTensorSpec([2, None], tf.float32))

And it failed with "<lambda>() missing 1 required positional argument: '_'" (which is expected, since your lambda takes one argument, but you didn't supply any arguments when you called tf.py_function). If I change the lambda to not expect any arguments:

tmp = lambda: tf.ragged.constant([[1., 2., 3.], [4., 5.]])
tf.py_function(tmp, [], Tout=tf.RaggedTensorSpec([2, None], tf.float32))

Then it succeeds for me (in TF 2.7).
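If the lambda genuinely needs an argument, it should be supplied through the inp list instead. A small sketch along those lines (the scaling factor is illustrative; assumes TF >= 2.7):

```python
import tensorflow as tf

# The lambda's argument comes from inp, not from the enclosing scope.
tmp = lambda n: tf.ragged.constant([[1., 2., 3.], [4., 5.]]) * n
rt = tf.py_function(tmp, inp=[2.0],
                    Tout=tf.RaggedTensorSpec([2, None], tf.float32))
print(rt)  # each value doubled: [[2.0, 4.0, 6.0], [8.0, 10.0]]
```

Elementwise ops such as multiplication by a scalar dispatch to RaggedTensor, so the scaled result keeps its ragged structure.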
