
Fix crash in CLR optimizer callback #2172

Merged (2 commits) Sep 25, 2020
Conversation

DavidWAbrahams
Contributor
@DavidWAbrahams DavidWAbrahams commented Sep 22, 2020

I experienced the following stack trace when the TF 2.4 nightly passes "step" as an int64. This happens, for example, when using the nightly version of TensorBoard.

Minimal repro: https://pastebin.com/Ni2dgDgh

Environment:
tf-nightly-gpu 2.4.0.dev20200917
tfa-nightly 0.12.0.dev20200918223509
tb-nightly 2.4.0a20200921

Description

 File "...\Python38\lib\site-packages\tensorflow\python\keras\engine\training.py", line 1117, in fit
    callbacks.on_epoch_end(epoch, epoch_logs)
  File "...\Python38\lib\site-packages\tensorflow\python\keras\callbacks.py", line 427, in on_epoch_end
    callback.on_epoch_end(epoch, logs)
  File "...\Python38\lib\site-packages\tensorflow\python\keras\callbacks.py", line 2274, in on_epoch_end
    self._log_epoch_metrics(epoch, logs)
  File "...\Python38\lib\site-packages\tensorflow\python\keras\callbacks.py", line 2316, in _log_epoch_metrics
    train_logs = self._collect_learning_rate(train_logs)
  File "...\Python38\lib\site-packages\tensorflow\python\keras\callbacks.py", line 2301, in _collect_learning_rate
    logs['learning_rate'] = lr_schedule(self.model.optimizer.iterations)
  File "...\Python38\lib\site-packages\tensorflow_addons\optimizers\cyclical_learning_rate.py", line 94, in __call__
    cycle = tf.floor(1 + step / (2 * step_size))
  File "...\Python38\lib\site-packages\tensorflow\python\ops\variables.py", line 1074, in _run_op
    return tensor_oper(a.value(), *args, **kwargs)
  File "...\Python38\lib\site-packages\tensorflow\python\ops\math_ops.py", line 1155, in binary_op_wrapper
    raise e
  File "...\Python38\lib\site-packages\tensorflow\python\ops\math_ops.py", line 1139, in binary_op_wrapper
    return func(x, y, name=name)
  File "...\Python38\lib\site-packages\tensorflow\python\util\dispatch.py", line 201, in wrapper
    return target(*args, **kwargs)
  File "...\Python38\lib\site-packages\tensorflow\python\ops\math_ops.py", line 1311, in truediv
    return _truediv_python3(x, y, name)
  File "...\Python38\lib\site-packages\tensorflow\python\ops\math_ops.py", line 1241, in _truediv_python3
    raise TypeError("x and y must have the same dtype, got %r != %r" %
TypeError: x and y must have the same dtype, got tf.int64 != tf.float32
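For context, the math at the crash site is simple. Here is a pure-Python sketch (no TensorFlow; the helper name `cyclical_lr` and the default arguments are hypothetical) of the triangular cyclical-learning-rate formula, including the explicit step-to-float conversion that stands in for the `tf.cast` this PR adds:

```python
import math

def cyclical_lr(step, initial_lr=1e-3, maximal_lr=1e-2, step_size=2000,
                scale_fn=lambda x: 1.0):
    """Sketch of the triangular CLR schedule (hypothetical helper, not the TFA API)."""
    step = float(step)  # mirrors tf.cast(step, dtype): an int64 step becomes a float
    cycle = math.floor(1 + step / (2 * step_size))
    x = abs(step / step_size - 2 * cycle + 1)
    return initial_lr + (maximal_lr - initial_lr) * max(0.0, 1.0 - x) * scale_fn(cycle)
```

Without the `float(step)` line, plain Python would still work (it promotes ints automatically), but TensorFlow's `truediv` refuses to mix `tf.int64` and `tf.float32`, which is exactly the `TypeError` above.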

Brief Description of the PR:

Fixes # (issue)

Checklist:

  • I've properly formatted my code according to the guidelines
    • By running Black + Flake8
    • By running pre-commit hooks
  • This PR addresses an already submitted issue for TensorFlow Addons
  • I have made corresponding changes to the documentation
  • I have added tests that prove my fix is effective or that my feature works
  • This PR contains modifications to C++ custom-ops

How Has This Been Tested?

Run on
tf-nightly-gpu 2.4.0.dev20200917
and
tfa-nightly 0.12.0.dev20200918223509

My project code looks like:

clr = CyclicalLearningRate(
    initial_learning_rate=1e-3,
    maximal_learning_rate=1e-2,
    step_size=3 * STEPS_PER_EPOCH_TRAIN,
    scale_fn=lambda x: 1.0,
)
loss = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
optimizer = LAMB(clr)
model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])

If you're adding a bugfix or new feature please describe the tests that you ran to verify your changes:

I experienced the stack trace shown above when the TF 2.4 nightly passed "step" as an int64. Perhaps that is new behavior?

@googlebot

Thanks for your pull request. It looks like this may be your first contribution to a Google open source project (if not, look below for help). Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

📝 Please visit https://cla.developers.google.com/ to sign.

Once you've signed (or fixed any issues), please reply here with @googlebot I signed it! and we'll verify it.



@bot-of-gabrieldemarmiesse

@RaphaelMeudec

You are the owner of some files modified in this pull request.
Would you kindly review the changes whenever you have the time?
Thank you very much.

@googlebot

CLAs look good, thanks!

@DavidWAbrahams
Contributor Author

@googlebot I signed it!

@@ -91,8 +91,9 @@ def __call__(self, step):
        dtype = initial_learning_rate.dtype
        maximal_learning_rate = tf.cast(self.maximal_learning_rate, dtype)
        step_size = tf.cast(self.step_size, dtype)
        cycle = tf.floor(1 + step / (2 * step_size))
        x = tf.abs(step / step_size - 2 * cycle + 1)
+       step_as_dtype = tf.cast(step, dtype)
Member

How about

step = tf.cast(step, dtype)

This should conform to the naming convention in this script.

Contributor Author

That was also my instinct, but it causes a unit test failure. ("step" is accessed later in the method and is expected to still have the original dtype)

I could do this without a local variable, just using "tf.cast(step, dtype)" where needed. Would that be better?

Member

@DavidWAbrahams Thanks for clarification. Your approach is better.

Can you also share the minimal runnable code snippet to reproduce original issue? Thank you!

Contributor Author
@DavidWAbrahams DavidWAbrahams Sep 22, 2020

Here's my minimal repro. It only triggers when I have added a TensorBoard callback. I guess that callback is somehow clobbering the dtype of "step".

https://pastebin.com/Ni2dgDgh

Contributor Author

For more detail, my versions are
tf-nightly-gpu 2.4.0.dev20200917
tfa-nightly 0.12.0.dev20200918223509
tb-nightly 2.4.0a20200921

@bhack
Contributor
bhack commented Sep 22, 2020

Is this related to tensorflow/tensorflow#26407?

@DavidWAbrahams DavidWAbrahams changed the title Fix crash in CRL optimizer callback Fix crash in CLR optimizer callback Sep 22, 2020
@DavidWAbrahams
Contributor Author

@bhack Thanks, yes that is probably what triggers my crash.

But even if that issue is eventually fixed, I think it's best if the cyclical learning rate callback sanitizes its inputs.
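The input sanitizing argued for above amounts to one defensive cast at the top of a schedule's `__call__`, so the schedule works no matter what integer or float type the caller hands it. A minimal pure-Python sketch of that pattern, assuming a made-up `LinearSchedule` class (not the TFA or Keras API):

```python
class LinearSchedule:
    """Hypothetical schedule illustrating the defensive-cast pattern."""

    def __init__(self, initial_lr, decay_steps):
        self.initial_lr = float(initial_lr)
        self.decay_steps = float(decay_steps)

    def __call__(self, step):
        # Sanitize: accept int, int64-like, or float step values,
        # the analogue of tf.cast(step, dtype) in the TF version.
        step = float(step)
        frac = min(step / self.decay_steps, 1.0)
        return self.initial_lr * (1.0 - frac)
```

The point of the pattern is that the cast happens once, inside the schedule, rather than relying on every caller (Keras, TensorBoard, user code) to pass a step of the expected dtype.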

@WindQAQ WindQAQ self-requested a review September 25, 2020 01:14
Member
@WindQAQ WindQAQ left a comment

Thank you!

@WindQAQ WindQAQ merged commit c7b867a into tensorflow:master Sep 25, 2020
jrruijli pushed a commit to jrruijli/addons that referenced this pull request Dec 23, 2020
* Fix crash in CRL optimizer callback

* Attempt to fix unit test failure