
Custom Gradient Computation not working in TF 2.14 #62053

Open
ayulockin opened this issue Oct 5, 2023 · 4 comments
Labels
comp:ops OPs related issues · regression issue To spot regression issues in latest version · stat:awaiting tensorflower Status - Awaiting response from tensorflower · TF2.14 For issues related to Tensorflow 2.14.x · type:bug Bug

Comments

@ayulockin

Issue type

Bug

Have you reproduced the bug with TensorFlow Nightly?

No

Source

binary

TensorFlow version

2.14.0

Custom code

Yes

OS platform and distribution

Linux

Mobile device

No response

Python version

3.10.12

Bazel version

No response

GCC/compiler version

No response

CUDA/cuDNN version

No response

GPU model and memory

No response

Current behavior?

I am using W&B's Keras callback, WandbCallback. This callback has a feature to log the gradients of each layer at every step. The feature works fine up to TF 2.13.0 but errors out in TF 2.14.0.

The following code works in TF 2.13.0 but errors out in TF 2.14.0:

import numpy as np
import tensorflow as tf
print(tf.__version__)
import wandb
from wandb.keras import WandbModelCheckpoint
from wandb.keras import WandbCallback

run = wandb.init(project="keras")

x = np.random.randint(255, size=(100, 28, 28, 1))
y = np.random.randint(10, size=(100,))

dataset = (x, y)


def get_model():
    m = tf.keras.Sequential()
    m.add(tf.keras.layers.Conv2D(3, 3, activation="relu", input_shape=(28, 28, 1)))
    m.add(tf.keras.layers.Flatten())
    m.add(tf.keras.layers.Dense(10, activation="softmax"))
    return m


model = get_model()
model.compile(
    loss="sparse_categorical_crossentropy",
    optimizer="sgd",
    metrics=["accuracy"],
)

model.fit(
    x,
    y,
    epochs=5,
    validation_data=(x, y),
    callbacks=[
        WandbCallback(
            save_model=False,
            log_gradients=True,
            training_data=(x,y)
        )
    ],
)

I investigated further and narrowed it down to the gradient-logging logic, which again works in 2.13.0 but not in 2.14.0.

I think this has to do with the breaking changes to tf.Tensor in 2.14.

The code below is that gradient-logging logic; it errors out in the latest version.

Standalone code to reproduce the issue

import tensorflow as tf
print(tf.__version__)
import wandb
import numpy as np

_training_data_x = np.random.randint(255, size=(100, 28, 28, 1))
_training_data_y = np.random.randint(10, size=(100,))


def get_model():
    m = tf.keras.Sequential()
    m.add(tf.keras.layers.Conv2D(3, 3, activation="relu", input_shape=(28, 28, 1)))
    m.add(tf.keras.layers.Flatten())
    m.add(tf.keras.layers.Dense(10, activation="softmax"))
    return m

model = get_model()
model.compile(
    loss="sparse_categorical_crossentropy",
    optimizer="sgd",
    metrics=["accuracy"],
)


def _get_custom_optimizer_parent_class():
    from pkg_resources import parse_version

    if parse_version(tf.__version__) >= parse_version("2.9.0"):
        custom_optimizer_parent_class = tf.keras.optimizers.legacy.Optimizer
    else:
        custom_optimizer_parent_class = tf.keras.optimizers.Optimizer

    return custom_optimizer_parent_class


_custom_optimizer_parent_class = _get_custom_optimizer_parent_class()
print(_custom_optimizer_parent_class)


class _CustomOptimizer(_custom_optimizer_parent_class):
    def __init__(self):
        super().__init__(name="CustomOptimizer")
        self._resource_apply_dense = tf.function(self._resource_apply_dense)
        self._resource_apply_sparse = tf.function(self._resource_apply_sparse)
        tf.print(self._resource_apply_dense)

    def _resource_apply_dense(self, grad, var):
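        # Overwrite the variable with its gradient so that, after the step,
        # the "weights" read back by the callback are actually the gradients.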
        var.assign(grad)

    # this needs to be implemented to prevent a NotImplementedError when
    # using Lookup layers.
    def _resource_apply_sparse(self, grad, var, indices):
        pass

    def get_config(self):
        return super().get_config()


class _GradAccumulatorCallback(tf.keras.callbacks.Callback):
    """Accumulates gradients during a fit() call when used in conjunction with the CustomOptimizer above."""

    def set_model(self, model):
        super().set_model(model)
        self.og_weights = model.get_weights()
        self.grads = [np.zeros(tuple(w.shape)) for w in model.trainable_weights]

    def on_batch_end(self, batch, logs=None):
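        # The custom optimizer has just replaced each weight with its gradient,
        # so summing the weights here accumulates gradients; the original
        # weights are restored below before the next batch.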
        for g, w in zip(self.grads, self.model.trainable_weights):
            g += w.numpy()
        self.model.set_weights(self.og_weights)

    def get_grads(self):
        return [g.copy() for g in self.grads]


inputs = model.inputs
print(inputs)
outputs = model(inputs)
grad_acc_model = tf.keras.models.Model(inputs, outputs)
grad_acc_model.compile(loss=model.loss, optimizer=_CustomOptimizer())

_grad_accumulator_model = grad_acc_model
_grad_accumulator_model.summary()

_grad_accumulator_callback = _GradAccumulatorCallback()


_grad_accumulator_model.fit(
    _training_data_x,
    _training_data_y,
    verbose=0,
    callbacks=[_grad_accumulator_callback],
)

weights = model.trainable_weights
grads = _grad_accumulator_callback.grads
print(weights)

metrics = {}
for weight, grad in zip(weights, grads):
    metrics[
        "gradients/" + weight.name.split(":")[0] + ".gradient"
    ] = wandb.Histogram(grad)

print(metrics)

Relevant log output

Traceback (most recent call last):
  File "/home/ayushthakur/client/wandb/test_grad_logging.py", line 88, in <module>
    _grad_accumulator_model.fit(
  File "/opt/conda/envs/tf214/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 70, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/tmp/__autograph_generated_file4zq8l42d.py", line 15, in tf__train_function
    retval_ = ag__.converted_call(ag__.ld(step_function), (ag__.ld(self), ag__.ld(iterator)), None, fscope)
TypeError: in user code:

    File "/opt/conda/envs/tf214/lib/python3.10/site-packages/keras/src/engine/training.py", line 1377, in train_function  *
        return step_function(self, iterator)
    File "/opt/conda/envs/tf214/lib/python3.10/site-packages/keras/src/engine/training.py", line 1360, in step_function  **
        outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "/opt/conda/envs/tf214/lib/python3.10/site-packages/keras/src/engine/training.py", line 1349, in run_step  **
        outputs = model.train_step(data)
    File "/opt/conda/envs/tf214/lib/python3.10/site-packages/keras/src/engine/training.py", line 1130, in train_step
        self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
    File "/opt/conda/envs/tf214/lib/python3.10/site-packages/keras/src/optimizers/legacy/optimizer_v2.py", line 601, in minimize
        return self.apply_gradients(grads_and_vars, name=name)
    File "/opt/conda/envs/tf214/lib/python3.10/site-packages/keras/src/optimizers/legacy/optimizer_v2.py", line 760, in apply_gradients
        return tf.__internal__.distribute.interim.maybe_merge_call(
    File "/opt/conda/envs/tf214/lib/python3.10/site-packages/keras/src/optimizers/legacy/optimizer_v2.py", line 844, in _distributed_apply
        with tf.control_dependencies([tf.group(update_ops)]):

    TypeError: 'inputs' should be zero or more (nested) Tensors. Received 'None' with type '<class 'NoneType'>'.
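
For context (my reading of the traceback, not something stated elsewhere in this report): tf.group raises exactly this TypeError during graph tracing when the list of update ops it receives contains a None entry, which can happen when a custom optimizer's _resource_apply_dense / _resource_apply_sparse methods return nothing. A minimal sketch that reproduces the same message under TF 2.14, assuming that is what happens inside _distributed_apply:

import tensorflow as tf

@tf.function
def apply_update():
    # A None entry stands in for an optimizer update method that returned nothing.
    update_ops = [None]
    with tf.control_dependencies([tf.group(update_ops)]):
        return tf.constant(0.0)

apply_update()
# TypeError: 'inputs' should be zero or more (nested) Tensors.
# Received 'None' with type '<class 'NoneType'>'.
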
@ayulockin
Author

From the Relevant log output above: if I remove tf.group from keras/src/optimizers/legacy/optimizer_v2.py, line 844 (in _distributed_apply), the issue seems to be resolved. I am not sure how to resolve it properly, though; any help would be appreciated.
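
One way this might be fixed on the callback side (a sketch of a suggestion, not a confirmed resolution) is to have _CustomOptimizer return concrete ops from its update methods, so the update_ops list handed to tf.group never contains None:

class _CustomOptimizer(_custom_optimizer_parent_class):
    def __init__(self):
        super().__init__(name="CustomOptimizer")
        self._resource_apply_dense = tf.function(self._resource_apply_dense)
        self._resource_apply_sparse = tf.function(self._resource_apply_sparse)

    def _resource_apply_dense(self, grad, var):
        # Return the assign op so that Keras' _distributed_apply collects a real
        # op instead of None before calling tf.group(update_ops).
        return var.assign(grad)

    def _resource_apply_sparse(self, grad, var, indices):
        # Likewise return a concrete (no-op) operation rather than None.
        return tf.no_op()

    def get_config(self):
        return super().get_config()

This keeps the gradient-capturing behaviour (each variable is still overwritten with its gradient) while giving _distributed_apply real ops to group.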

@sushreebarsa sushreebarsa added the TF2.14 For issues related to Tensorflow 2.14.x label Oct 6, 2023
@sushreebarsa
Contributor

@ayulockin Thank you for raising this issue!
I was able to replicate the issue on Colab using TF v2.14 and the nightly build; it does not appear in 2.13.
@sachinprasadhs Could you please have a look? Thank you!

@sushreebarsa sushreebarsa added the comp:ops OPs related issues label Oct 6, 2023
@sushreebarsa sushreebarsa added the regression issue To spot regression issues in latest version label Oct 6, 2023
@ayulockin
Author

Thanks for quickly replicating it, @sushreebarsa. Looking forward to a response from @sachinprasadhs :)

@sachinprasadhs sachinprasadhs added the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Oct 10, 2023
@mywork-team1
mywork-team1 commented Dec 19, 2023

I solved the issue like this:

class CustomSGD(tf.keras.optimizers.legacy.Optimizer):
    pass
    ...

Creating an instance of the custom SGD optimizer:

custom_sgd_optimizer = CustomSGD(learning_rate=0.01)

When compiling the model, I added run_eagerly=True:

model3.compile(optimizer=custom_sgd_optimizer, loss='mse', metrics=['mse', 'mae'], run_eagerly=True)

Then I trained the model with:

history3 = model3.fit(x, y, epochs=1200)

and it ran correctly without any issues.

Hope this helps! 🙂
If anything is unclear, feel free to email me.
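
For readers who want to try the workaround above end to end, here is a self-contained sketch of one reading of it. Most of it is reconstruction: the elided CustomSGD body is replaced with a plain subclass of tf.keras.optimizers.legacy.SGD, and x, y, and model3 are toy stand-ins for the commenter's data and model; the parts actually taken from the comment are the legacy optimizer base class, learning_rate=0.01, and run_eagerly=True.

import numpy as np
import tensorflow as tf


class CustomSGD(tf.keras.optimizers.legacy.SGD):
    # Subclassing the legacy SGD keeps the optimizer on the pre-2.14 legacy code path.
    pass


x = np.random.rand(100, 4).astype("float32")
y = np.random.rand(100, 1).astype("float32")

model3 = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])

custom_sgd_optimizer = CustomSGD(learning_rate=0.01)

# run_eagerly=True makes the train step execute eagerly instead of being traced
# into a graph, so the failing tf.group(update_ops) validation is never hit.
model3.compile(
    optimizer=custom_sgd_optimizer,
    loss="mse",
    metrics=["mse", "mae"],
    run_eagerly=True,
)

history3 = model3.fit(x, y, epochs=1200, verbose=0)

Note that run_eagerly=True disables graph execution of the train step, so training will be noticeably slower; in eager mode tf.group returns early without validating its inputs, which is presumably why the TypeError no longer appears.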
