
Optimizing slice of variable not possible #22059

Open
Hoeze opened this issue Sep 4, 2018 · 11 comments

Labels
comp:ops OPs related issues · stat:awaiting tensorflower Status - Awaiting response from tensorflower · TF 2.11 Issues related to TF 2.11 · type:bug Bug

Comments

Hoeze commented Sep 4, 2018

Applying the gradient of a variable slice currently results in a NotImplementedError from tf.train.Optimizer.

The following two examples work:

### WORKING ###
import tensorflow as tf

X = tf.Variable(2, dtype=tf.float32)
y = tf.constant(10, dtype="float32")
loss = y - (X * X)

variables = [X]
gradient = tf.gradients(loss, variables)
gradient = [(g, v) for g, v in zip(gradient, variables)]
train_op = tf.train.AdamOptimizer().apply_gradients(gradient)

### WORKING ###
big_X = tf.Variable([2, 3, 4], dtype=tf.float32)
X = big_X[0]
y = tf.constant(10, dtype="float32")
loss = y - (X * X)

train_op = tf.train.AdamOptimizer().minimize(loss)

The following example throws an error:

### NOT WORKING ###
big_X = tf.Variable([2, 3, 4], dtype=tf.float32)
X = big_X[0]
y = tf.constant(10, dtype="float32")
loss = y - (X * X)

variables = [X]
gradient = tf.gradients(loss, variables)
gradient = [(g, v) for g, v in zip(gradient, variables)]
train_op = tf.train.AdamOptimizer().apply_gradients(gradient)

The error:

Traceback (most recent call last):
  File "/usr/local/lib/python3.6/dist-packages/IPython/core/interactiveshell.py", line 2963, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-22-10282dee2005>", line 10, in <module>
    train_op = tf.train.AdamOptimizer().apply_gradients(gradient)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/training/optimizer.py", line 605, in apply_gradients
    update_ops.append(processor.update_op(self, grad))
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/training/optimizer.py", line 189, in update_op
    raise NotImplementedError("Trying to update a Tensor ", self._v)
NotImplementedError: ('Trying to update a Tensor ', <tf.Tensor 'strided_slice_9:0' shape=() dtype=float32>)

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): yes
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Ubuntu 18.04
  • Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device: NA
  • TensorFlow installed from (source or binary): binary
  • TensorFlow version (use command below): v1.10.1-0-g4dcfddc5d1 1.10.1
  • Python version: 3.6.5
  • Bazel version (if compiling from source): NA
  • GCC/Compiler version (if compiling from source): NA
  • CUDA/cuDNN version: NA
  • GPU model and memory: NA
  • Exact command to reproduce: NA
Hoeze (Author) commented Sep 5, 2018

I debugged this problem: since the variable slice is a Tensor, it gets wrapped by the _TensorProcessor class in tensorflow/python/training/optimizer.py, and because _TensorProcessor does not implement update_op(), the gradient cannot be applied.

There are multiple possible solutions to this problem (a workaround sketch follows the list):

  • Make variable slices return another special variable type
  • Implement update_op() in _TensorProcessor for tensors that support some .assign() method
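
Until something like this lands, a possible user-side workaround is to differentiate with respect to the whole variable and mask the gradient before applying it. A minimal sketch in the TF 1.x API used above (the mask values are illustrative, not part of any TensorFlow API):

### WORKAROUND SKETCH ###
import tensorflow as tf

big_X = tf.Variable([2, 3, 4], dtype=tf.float32)
X = big_X[0]
y = tf.constant(10, dtype="float32")
loss = y - (X * X)

# Differentiate w.r.t. the full variable instead of the slice,
# then zero out the entries that should stay fixed.
grad = tf.gradients(loss, [big_X])[0]
mask = tf.constant([1, 0, 0], dtype=tf.float32)  # illustrative: train only index 0
train_op = tf.train.AdamOptimizer().apply_gradients([(grad * mask, big_X)])

This is not exactly equivalent to updating a slice: Adam still allocates slot variables for all of big_X, though with a freshly created optimizer the masked entries receive zero updates as long as their gradient stays zero.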

nellymin commented

I cannot quite understand the reasons for this not being implemented. Combined with the still-unresolved issue #1325, where an imported model's variables are not visible, this makes retraining a pretrained model practically impossible.

@jvishnuvardhan jvishnuvardhan self-assigned this Feb 13, 2019
@jvishnuvardhan jvishnuvardhan added comp:ops OPs related issues type:bug Bug labels Feb 13, 2019
jvishnuvardhan (Contributor) commented

@Hoeze Is this still an issue? Thanks!

Hoeze (Author) commented Feb 14, 2019

@jvishnuvardhan Yes, the issue is still the same.

@jvishnuvardhan jvishnuvardhan added the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Feb 15, 2019
jonas-eschle (Contributor) commented Apr 29, 2021

Is there any news on this? Implementing this would be helpful for us as well.

maxhgerlach (Contributor) commented

Hi @rmlarsen, is some extension to lift this limitation on the TensorFlow roadmap? Applying an optimizer to a slice of a variable would unblock quite a few more flexible approaches to training certain classes of models.

mohantym (Contributor) commented Aug 11, 2021

@Hoeze!

I was able to replicate the issue in 2.11 and TF Nightly 2.12.0-dev20221215.

Thank you!

@mohantym mohantym added the TF 2.5 Issues related to TF 2.5 label Aug 12, 2021
@mohantym mohantym added TF 2.9 Issues found in the TF 2.9 release (or RCs) regression issue To spot regression issues in latest version and removed TF 2.5 Issues related to TF 2.5 labels Jul 19, 2022
@mohantym mohantym added TF 2.11 Issues related to TF 2.11 and removed TF 2.9 Issues found in the TF 2.9 release (or RCs) labels Feb 7, 2023
SNMS95 commented Mar 8, 2023

Are there any updates on this?

Emalude commented Jun 28, 2023

Hi, any updates about this issue? I'm using a TensorFlow-based library for Knowledge Graph Embedding models to develop a gradient-based explainability method that tracks the influence of single training samples. Given the high number of parameters in such models, it would be very useful if TensorFlow supported computing gradients for only a subset of training variables.
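
One hedged workaround for this subset-of-variables use case in the TF 2.x API: route the slice through tf.gather, whose gradient with respect to the full variable is a tf.IndexedSlices that optimizers can apply sparsely. The indices and loss below are illustrative only:

import tensorflow as tf

big_X = tf.Variable([2.0, 3.0, 4.0])
rows = tf.constant([0])  # illustrative: train only entry 0
optimizer = tf.keras.optimizers.Adam()

with tf.GradientTape() as tape:
    X = tf.gather(big_X, rows)  # differentiable "slice" of the variable
    loss = 10.0 - tf.reduce_sum(X * X)

# The gradient of tf.gather w.r.t. big_X is a tf.IndexedSlices,
# so the optimizer only touches the selected rows.
grad = tape.gradient(loss, big_X)
optimizer.apply_gradients([(grad, big_X)])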

tilakrayal (Contributor) commented

Hi,

Thank you for opening this issue. Since this issue has been open for a long time, the code/debug information it contains may no longer be relevant to the current state of the code base.

The TensorFlow team is constantly improving the framework by fixing bugs and adding new features. We suggest you try the latest TensorFlow version with the latest compatible hardware configuration, which could potentially resolve the issue. If you are still facing the issue, please create a new GitHub issue with your latest findings and all the debugging information that could help us investigate.

Please follow the release notes to stay up to date with the latest developments in the TensorFlow space.

@tilakrayal tilakrayal added stat:awaiting response Status - Awaiting response from author and removed regression issue To spot regression issues in latest version labels May 23, 2024
jonas-eschle (Contributor) commented

Actually, it still seems to be an issue, even when running with the nightlies.

@google-ml-butler google-ml-butler bot removed the stat:awaiting response Status - Awaiting response from author label May 24, 2024