
Add integer data types to IsTrainable for use with custom gradients #25386

Open
jvmncs opened this issue Jan 31, 2019 · 9 comments
Labels: comp:ops (OPs related issues), stat:awaiting response (Status - Awaiting response from author), type:feature (Feature requests)

jvmncs commented Jan 31, 2019

System information

  • TensorFlow version (you are using): 1.12.0
  • Are you willing to contribute it (Yes/No): Yes

Describe the feature and the current behavior/state.
Currently, there are a limited number of DTypes usable with autograd (see IsTrainable and _IsBackpropagatable below).

def IsTrainable(tensor_or_dtype):
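
Roughly, that check is a dtype whitelist; a simplified sketch (not the exact source, which may whitelist slightly different dtypes) of the behavior:

from tensorflow.python.framework import dtypes

def is_trainable_sketch(tensor_or_dtype):
    # Simplified sketch: gradients only flow for floating-point and
    # complex dtypes; integer dtypes fall through, so their gradients
    # come back as None.
    dtype = getattr(tensor_or_dtype, "dtype", tensor_or_dtype)
    dtype = dtypes.as_dtype(dtype).base_dtype
    return dtype in (dtypes.bfloat16, dtypes.float16, dtypes.float32,
                     dtypes.float64, dtypes.complex64, dtypes.complex128)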

In particular, nodes with integer tensors automatically generate gradients of None, which breaks any call to tf.gradients. This is a reasonable assumption for standard machine learning practice; however, it prevents users from building their own autograd systems with tf.custom_gradient. I'd contend that, as a general-purpose automatic differentiation library, TF should enable experimenting with such uses of its autograd system. This would be particularly useful when performing automatic differentiation over rings or finite fields, which are often represented in computers as sets of consecutive integers (e.g. Z_p).
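
To illustrate (a minimal sketch in the TF 1.x API; identity_with_grad is a hypothetical example function), even an explicitly defined custom gradient is discarded for integer inputs:

import tensorflow as tf

@tf.custom_gradient
def identity_with_grad(x):
    def grad(dy):
        return dy  # an explicit, well-defined gradient...
    return x, grad

x = tf.constant([[1, 2], [3, 4]], dtype=tf.int32)
y = identity_with_grad(x)
# ...is discarded anyway, because int32 fails the IsTrainable check:
print(tf.gradients([y], [x]))  # [None]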

Will this change the current API? How?
This change shouldn't affect any standard usage of TensorFlow -- all autograd on floats will be unchanged. The only change will occur when calling autograd on graphs that compute integer arithmetic and the like.

Currently, performing tf.gradients on graphs with integer tensors will raise an unhandled error:

import numpy as np
import tensorflow as tf

# Two int32 variables; their sum z is an integer tensor.
x_back = np.ones([2, 2])
y_back = np.ones([2, 2])
x = tf.Variable(x_back, dtype=tf.int32)
y = tf.Variable(y_back, dtype=tf.int32)
z = x + y

# int32 is not trainable, so this returns [None, None]...
vg = tf.gradients([z], [x, y])
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # ...and fetching None is what raises the TypeError below.
    out = sess.run(vg)
    print(out)

Returns:

Traceback (most recent call last):
  File "/Users/jasonmancuso/dropout/research/customgrad/issue_min.py", line 16, in <module>
    out = sess.run(vg)
  File "/Users/jasonmancuso/anaconda/envs/tf-encrypted/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 929, in run
    run_metadata_ptr)
  File "/Users/jasonmancuso/anaconda/envs/tf-encrypted/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1137, in _run
    self._graph, fetches, feed_dict_tensor, feed_handles=feed_handles)
  File "/Users/jasonmancuso/anaconda/envs/tf-encrypted/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 471, in __init__
    self._fetch_mapper = _FetchMapper.for_fetch(fetches)
  File "/Users/jasonmancuso/anaconda/envs/tf-encrypted/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 261, in for_fetch
    return _ListFetchMapper(fetch)
  File "/Users/jasonmancuso/anaconda/envs/tf-encrypted/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 370, in __init__
    self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
  File "/Users/jasonmancuso/anaconda/envs/tf-encrypted/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 370, in <listcomp>
    self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
  File "/Users/jasonmancuso/anaconda/envs/tf-encrypted/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 258, in for_fetch
    type(fetch)))
TypeError: Fetch argument None has invalid type <class 'NoneType'>

See #783 (comment) for other cases in which Ops can generate a None gradient.

After implementing this feature, we'd return the following:

[array([[1, 1], [1, 1]], dtype=int32),
 array([[1, 1], [1, 1]], dtype=int32)]

Who will benefit from this feature?
Anyone who wants to perform automatic differentiation on integer data types, as would be the case when operating in integer rings or finite fields.
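
For example, arithmetic in Z_p is naturally expressed with integer tensors plus modular reduction; a hypothetical sketch (P is an illustrative modulus):

import tensorflow as tf

P = 2 ** 31 - 1  # example modulus (a Mersenne prime), purely illustrative

x = tf.constant([[1, 2], [3, 4]], dtype=tf.int64)
y = tf.constant([[5, 6], [7, 8]], dtype=tf.int64)

# Field operations in Z_p: ordinary integer arithmetic followed by
# reduction mod p. Differentiating through graphs like this is the goal.
add_mod = (x + y) % P
mul_mod = (x * y) % P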

Any Other info.
This request is motivated by the tf-encrypted project.

@ymodak ymodak added comp:ops OPs related issues type:feature Feature requests labels Jan 31, 2019
skye (Member) commented Feb 1, 2019

This is a cool use case, but unfortunately I think it is too large a project to take on at this time. I think this would be a difficult feature to fully implement, because we'd have to make sure that every integer operation has a proper gradient function defined. Gradient functions are currently written with the assumption that we don't need to handle integers (e.g. we always return None for the indices param of a gather op). The result is that, without careful auditing and/or testing, many gradients() calls over integers would silently return None or even the wrong answer. This is a big enough feature that we'd probably want close collaboration with someone on the TF team.

If there's enough demand in the future it might be worth the effort, but for now I don't think we can properly deliver on this.

cc @alextp @ebrevdo @martinwicke -- maybe I'm missing something that makes this more tractable?

alextp (Contributor) commented Feb 1, 2019

I agree with @skye. It might be easier for you to bypass TF's gradient code entirely if you want to go this route, since our ops have gradients that don't behave well at all with integers.

mortendahl commented

@skye, to understand the problem better (and out of curiosity), would you mind pointing to a place where this assumption about returning None is made?

skye (Member) commented Feb 4, 2019

Here's one I happened to run into recently: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/array_grad.py#L410
Note that we always return None for the indices grad of gather. This is a niche case, but if you were attempting to take the gradient w.r.t. a variable that determines the indices, this grad function would silently return the wrong answer (None corresponds to a zero gradient, more or less). There are likely more; part of the difficulty is combing through all the gradient functions and finding them :)
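
A quick way to see this behavior (TF 1.x graph mode, as in the original report):

import tensorflow as tf

params = tf.constant([10.0, 20.0, 30.0])
indices = tf.constant([0, 2])
gathered = tf.gather(params, indices)

# The gradient w.r.t. params is defined (an IndexedSlices), but the
# registered gradient function hard-codes None for the indices input.
print(tf.gradients([gathered], [params, indices]))
# -> [<IndexedSlices ...>, None]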

shaunster0 commented

I have this need for integer tensors and their gradients too... It's disappointing that nothing seems to be happening with this; I think it will restrict future innovation in some interesting areas.

@rmothukuru rmothukuru added the stat:contribution welcome Status - Contributions welcome label May 26, 2021
@mohantym mohantym self-assigned this Jul 27, 2022
mohantym (Contributor) commented

Hi @jvmncs!

This is no longer throwing any error in version 2.8. Gist attached for reference.

Thank you!

@mohantym mohantym added stat:awaiting response Status - Awaiting response from author and removed stat:contribution welcome Status - Contributions welcome labels Jul 27, 2022
jvmncs (Author) commented Jul 27, 2022

Hi @mohantym, thanks for the response. It's nice to see that the error is no longer being thrown, but unfortunately I don't think that suffices to consider this issue solved. In particular, the check that forces a jump out of the backprop logic seems to be the same, just moved to a new location here. Thus it's still not possible for integer-typed tensors to have non-None gradients, even when defined explicitly with a tf.custom_gradient.
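
For instance (a minimal TF 2.x sketch; identity_with_grad is a hypothetical helper), the tape still returns None for an integer input even with an explicit custom gradient:

import tensorflow as tf

@tf.custom_gradient
def identity_with_grad(x):
    def grad(dy):
        return dy
    return x, grad

x = tf.constant([[1, 2], [3, 4]], dtype=tf.int32)
with tf.GradientTape() as tape:
    # watching an integer tensor may also log a warning, depending on version
    tape.watch(x)
    y = identity_with_grad(x)
print(tape.gradient(y, x))  # None, despite the explicit custom gradient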

FWIW, I think my understanding of TF's goals and aims has evolved. I would not be surprised if this feature were considered a non-goal for the project; it's a relatively niche request. Anyway, I think JAX might be a somewhat better fit for something so experimental.

@google-ml-butler google-ml-butler bot removed the stat:awaiting response Status - Awaiting response from author label Jul 27, 2022
mohantym (Contributor) commented Jul 28, 2022

@jvmncs! I was just checking for an update from your side. Thanks for the reply.

@mohantym mohantym assigned gadagashwini and unassigned mohantym Jul 28, 2022
@gadagashwini gadagashwini removed their assignment Aug 16, 2022
tilakrayal (Contributor) commented

Hi,

Thank you for opening this issue. Since this issue has been open for a long time, the code/debug information for this issue may no longer be relevant to the current state of the code base.

The TensorFlow team is constantly improving the framework by fixing bugs and adding new features. We suggest you try the latest TensorFlow version with the latest compatible hardware configuration, which could potentially resolve the issue. If you are still facing the issue, please create a new GitHub issue with your latest findings, along with all the debugging information that could help us investigate.

Please follow the release notes to stay up to date with the latest developments happening in the TensorFlow space.

@tilakrayal tilakrayal added the stat:awaiting response Status - Awaiting response from author label Jun 26, 2024