Support Sparse Tensors in py_function #30069

Closed
novog opened this issue Jun 24, 2019 · 15 comments
Labels
comp:ops OPs related issues stale This label marks the issue/pr stale - to be closed automatically if no activity stat:contribution welcome Status - Contributions welcome type:feature Feature requests

Comments

@novog
novog commented Jun 24, 2019
  • TensorFlow version (you are using): 2.0.0b0
  • Are you willing to contribute it (Yes/No): No

Describe the feature and the current behavior/state.
Currently, sparse tensors don't seem to be supported as inputs to py_function. Attempting to pass one results in an error like "TypeError: Tensors in list passed to 'input' of 'EagerPyFunc' Op have types [<NOT CONVERTIBLE TO TENSOR>] that are invalid.". Example:

import tensorflow as tf


def dataset_map_sparse_test():
    input_data_sparse = tf.SparseTensor(indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])
    input_data_dense = tf.sparse.to_dense(input_data_sparse)

    ds_dense = tf.data.Dataset.from_tensor_slices(input_data_dense)
    ds_sparse = tf.data.Dataset.from_tensor_slices(input_data_sparse)

    def inner_fn(*input):
        return input

    def outer_fn(*input_data):
        return tf.py_function(
            inner_fn,
            input_data,
            (tf.int32,))

    # this works
    ds_dense_mapped = ds_dense.map(outer_fn)
    print(next(iter(ds_dense_mapped)))

    # this results in the error specified above
    ds_sparse_mapped = ds_sparse.map(outer_fn)
    print(next(iter(ds_sparse_mapped)))

Will this change the current api? How?
Yes. The inp parameter of py_function currently accepts a list of Tensor objects; this change would broaden it to accept either Tensors or SparseTensors. (Possibly _TensorLikes?)

Who will benefit with this feature?
Anyone wishing to pass sparse tensors to a py_function-wrapped function. There are situations where this wrapping is necessary; for example, functions passed to Dataset.map cannot perform certain operations unless the function is wrapped using py_function.

Any Other info.

@achandraa achandraa self-assigned this Jun 25, 2019
@achandraa achandraa added 2.0.0-beta0 comp:ops OPs related issues type:feature Feature requests labels Jun 25, 2019
@jvishnuvardhan jvishnuvardhan added the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Jun 27, 2019
@alextp
Contributor
alextp commented Jun 27, 2019

I think now that we have CompositeTensor and TypeSpec, it should be straightforward to implement SparseTensor support in py_function by treating all composites generically.

I don't have the bandwidth to do this now, so leaving as contributions welcome.

@alextp alextp added stat:contribution welcome Status - Contributions welcome and removed stat:awaiting tensorflower Status - Awaiting response from tensorflower labels Jun 27, 2019
@alextp alextp removed their assignment Jun 27, 2019
@pradyumna1

@alextp I would like to work on this. May I go ahead?

@novog
Author
novog commented Jul 2, 2019

Would "all composites" include ragged tensors, or should I create a separate feature request for that? Thanks!

@alextp
Contributor
alextp commented Jul 2, 2019

Composites include ragged tensors and a lot of other things as well.
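
To make the "treat all composites generically" idea concrete, here is a minimal sketch using the public tf.nest API with expand_composites=True. The helper name round_trip is made up for illustration, and this is only an assumption about what a generic approach could build on, not how py_function is implemented.

```python
import tensorflow as tf


def round_trip(value):
    """Flatten a (possibly composite) value into plain tensors, then rebuild it."""
    # expand_composites=True breaks a composite into its component tensors.
    components = tf.nest.flatten(value, expand_composites=True)
    # The original value serves as the structure template when repacking.
    rebuilt = tf.nest.pack_sequence_as(value, components, expand_composites=True)
    return components, rebuilt


sparse = tf.SparseTensor(indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])
ragged = tf.ragged.constant([[1, 2], [3]])

sparse_parts, rebuilt_sparse = round_trip(sparse)
ragged_parts, rebuilt_ragged = round_trip(ragged)

# Print how many plain tensors each composite expands to, and the rebuilt type.
print(len(sparse_parts), type(rebuilt_sparse).__name__)
print(len(ragged_parts), type(rebuilt_ragged).__name__)
```

The same two calls handle both SparseTensor and RaggedTensor, which is why a single generic change could plausibly cover both feature requests.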

@alextp
Contributor
alextp commented Jul 2, 2019

@pradyumna1 please go ahead!

@novog
Author
novog commented Aug 14, 2019

I notice that someone had already created issue #26453 (Allow py_function to support functions that return RaggedTensor). Implementing this feature using composite tensors would also satisfy that issue.

@minhtriet

I made changes in op_def_library.py so that CompositeTensor is now allowed. However, CompositeTensor lacks some features needed to add the node to the computation graph, for example ._as_tf_output(), which is used in the function _create_c_op. That method exists only in class Tensor. Could you advise on this, @alextp?

@novog
Author
novog commented Oct 23, 2019

@minhtriet, in a similar feature request (#27679), @edloper provided a helpful workaround in the form of a wrapper around py_function, which he suggested could be integrated into _internal_py_func. Does that help?
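
For readers who have not seen that thread, the general shape of such a wrapper might look like the sketch below. This is not the code from #27679; the name py_function_with_composites is hypothetical, and it only handles composite inputs (outputs must still be plain tensors).

```python
import tensorflow as tf


def py_function_with_composites(func, inp, Tout):
    """Hypothetical wrapper: lets `inp` contain composite tensors such as SparseTensor."""
    # py_function only accepts plain tensors, so expand composites into components.
    flat_inp = tf.nest.flatten(inp, expand_composites=True)

    def wrapped(*flat_args):
        # Rebuild the original structure (including composites) from the flat tensors.
        args = tf.nest.pack_sequence_as(inp, flat_args, expand_composites=True)
        return func(*args)

    return tf.py_function(wrapped, flat_inp, Tout)


sp = tf.SparseTensor(indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])

# The inner function receives a real SparseTensor, not three loose tensors.
total = py_function_with_composites(lambda s: tf.reduce_sum(s.values), [sp], tf.int32)
print(total)
```

Supporting composite outputs as well would need the same flatten/pack treatment applied to Tout, which this sketch leaves out.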

@minhtriet

Thank you @novog. I can't work on it right now, but I'll get back to you in a couple of days.

@minhtriet

I updated tf.py_function in my fork.

Rerunning the code, the error is:

2019-12-14 21:38:52.367270: I tensorflow/core/common_runtime/process_util.cc:115] Creating new thread pool with default inter op setting: 4. Tune using inter_op_parallelism_threads for best performance.
(<tf.Tensor: id=20, shape=(4,), dtype=int32, numpy=array([1, 0, 0, 0])>,)
2019-12-14 21:38:53.274875: W tensorflow/core/framework/op_kernel.cc:1622] OP_REQUIRES failed at iterator_ops.cc:929 : Invalid argument: {{function_node __inference_Dataset_map_outer_fn_28}} pyfunc_1 returns 3 values, but expects to see 1 values.
	 [[{{node PyFuncStateless}}]]
Traceback (most recent call last):
  File "C:/Users/cool/Code/tf_contrib_conda/test.py", line 28, in <module>
    dataset_map_sparse_test()
  File "C:/Users/cool/Code/tf_contrib_conda/test.py", line 26, in dataset_map_sparse_test
    print(next(iter(ds_sparse_mapped)))
  File "C:\Users\cool\Miniconda3\envs\tf_contrib_conda\lib\site-packages\tensorflow_core\python\data\ops\iterator_ops.py", line 622, in __next__
    return self.next()
  File "C:\Users\cool\Miniconda3\envs\tf_contrib_conda\lib\site-packages\tensorflow_core\python\data\ops\iterator_ops.py", line 666, in next
    return self._next_internal()
  File "C:\Users\cool\Miniconda3\envs\tf_contrib_conda\lib\site-packages\tensorflow_core\python\data\ops\iterator_ops.py", line 651, in _next_internal
    output_shapes=self._flat_output_shapes)
  File "C:\Users\cool\Miniconda3\envs\tf_contrib_conda\lib\site-packages\tensorflow_core\python\ops\gen_dataset_ops.py", line 2672, in iterator_get_next_sync
    _six.raise_from(_core._status_to_exception(e.code, message), None)
  File "<string>", line 3, in raise_from
tensorflow.python.framework.errors_impl.InvalidArgumentError: {{function_node __inference_Dataset_map_outer_fn_28}} pyfunc_1 returns 3 values, but expects to see 1 values.
	 [[{{node PyFuncStateless}}]] [Op:IteratorGetNextSync]

My guess is that creating the sparse tensor works, but next() fails because pyfunc_1 returned 3 values, which I could not debug. Could someone give me a pointer?

@alextp
Contributor
alextp commented Dec 16, 2019

This issue likely happens because the SparseTensor isn't being unpacked into its 3 component values.
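
The three components are indices, values, and dense_shape. A minimal sketch of handling them explicitly with the stock tf.py_function is shown here, run eagerly rather than inside Dataset.map. The function names double_values and map_sparse are invented for the example. Note that Tout declares three dtypes to match the three returned values, which is the count the error message above complains about.

```python
import tensorflow as tf


def double_values(indices, values, dense_shape):
    # Runs eagerly inside py_function and returns the three components of the result.
    return indices, values * 2, dense_shape


def map_sparse(sp):
    # Pass the components as plain tensors and declare one output dtype per component.
    indices, values, dense_shape = tf.py_function(
        double_values,
        [sp.indices, sp.values, sp.dense_shape],
        (tf.int64, tf.int32, tf.int64))
    # Reassemble the SparseTensor from the returned components.
    return tf.SparseTensor(indices, values, dense_shape)


sp = tf.SparseTensor(indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])
doubled = map_sparse(sp)
print(tf.sparse.to_dense(doubled))
```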

@minhtriet

Hi @alextp, thank you for the reply. Right now I am using tensorflow_core.python.util.nest.pack_sequence_as to produce the output, but _is_composite_tensor returns False, and I do not understand why.

The call stack is:

_sequence_like, nest.py:149
pack_sequence_as, nest.py:471
_get_defun_inputs, func_graph.py:1160
_get_defun_inputs_from_args, func_graph.py:1062
func_graph_from_py_func, func_graph.py:837
_create_graph_function, function.py:2041
_maybe_define_function, function.py:2150
_get_concrete_function_internal_garbage_collected, function.py:1848
_get_concrete_function_internal, function.py:1854
__init__, dataset_ops.py:2695
__init__, dataset_ops.py:3416
map, dataset_ops.py:1211
dataset_map_sparse_test, test.py:25          << This is the test from OP
<module>, test.py:28

@alextp
Contributor
alextp commented Jan 2, 2020

That doesn't make sense to me because SparseTensor is a composite:

class SparseTensor(_TensorLike, composite_tensor.CompositeTensor):

So it's likely something else that is going on here.
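
One way to check the composite behaviour from the public API, without touching internal modules, is sketched below. It only demonstrates how tf.nest treats a SparseTensor with and without expand_composites; whether that flag is related to the problem described above is not established in this thread.

```python
import tensorflow as tf

sp = tf.SparseTensor(indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])

# By default, nest treats the SparseTensor as a single leaf.
leaf_count = len(tf.nest.flatten(sp))

# With expand_composites=True, it is expanded into its component tensors.
component_count = len(tf.nest.flatten(sp, expand_composites=True))

# The TypeSpec describes how the composite is built from its components.
spec = tf.type_spec_from_value(sp)

print(leaf_count, component_count, spec)
```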

@github-actions

This issue is stale because it has been open for 180 days with no activity. It will be closed if no further activity occurs. Thank you.

@github-actions github-actions bot added the stale This label marks the issue/pr stale - to be closed automatically if no activity label Mar 28, 2023

This issue was closed because it has been inactive for 1 year.


6 participants