Gradient checkpointing for TF Keras models #38766
Labels: stat:awaiting tensorflower, type:feature
Comments
Does it make sense to port the above-mentioned PR as a PR to TF core? Also tagging @seanpmorgan from TF addons to keep in the loop. Thanks!
This functionality would be very helpful for training memory-intensive language models!
By the way, I am thinking of addressing this with separate PRs, as follows.
Does this approach make sense? Thanks!
@pidajay Did you make it?
System information
Describe the feature and the current behavior/state.
I have implemented a version of gradient checkpointing for TF Keras Sequential models (with future plans to extend it to the Keras functional API and custom models). The PR can be found here: tensorflow/addons#1600. I initially envisioned it as a package in the TF Addons repo, but reviewers felt that was not the right place and that it could potentially go in as a fix for the existing recompute_grad functionality in TF core -
tensorflow/tensorflow/python/ops/custom_gradient.py, line 458 at commit 64f4a59
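For context, here is a minimal sketch of how the existing tf.recompute_grad path is typically applied today. The model, layer sizes, and training step are illustrative only, and the exact behavior (especially around gradients with respect to variables) has varied across TF versions:

```python
import tensorflow as tf

# A block of layers whose activations we are willing to recompute
# in the backward pass instead of keeping them in memory.
block = tf.keras.Sequential([
    tf.keras.layers.Dense(1024, activation="relu"),
    tf.keras.layers.Dense(1024, activation="relu"),
])

# tf.recompute_grad wraps a callable so that its intermediate
# activations are not stored; they are recomputed when gradients
# are taken, trading extra compute time for lower peak memory.
checkpointed_block = tf.recompute_grad(lambda x: block(x))

x = tf.random.normal([32, 1024])
with tf.GradientTape() as tape:
    loss = tf.reduce_sum(checkpointed_block(x))

# Gradients flow through the recomputed block as usual.
grads = tape.gradient(loss, block.trainable_variables)
```

Note that with this path the user has to decide how to split the model and wrap each segment manually, which is one of the pain points the PR aims to remove.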
Here are the issues with the existing implementation of recompute_grad in TF core and my solutions to them:
Solution - My PR provides a notebook tutorial demonstrating how to use the implemented functionality.
Solution - My PR provides links to results with the observed memory savings. Caveat - only CPU-profiled results are available so far; GPU and TPU profiling still needs to be done.
Solution - My PR expects no explicit partitioning of the model; the user just needs to add a single decorator to the model (see the sketch after this list). That is it.
Solution - My PR implements the checkpointing functionality so that the user can balance the trade-off between memory and compute time.
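As a rough illustration only, the single-decorator usage might look something like the following. The decorator name `recompute_sequential` and the `num_checkpoints` argument are hypothetical stand-ins, not the actual API from the PR (see tensorflow/addons#1600 for that):

```python
import tensorflow as tf

# A plain Keras Sequential model; no manual partitioning of layers.
model = tf.keras.Sequential(
    [tf.keras.layers.Dense(4096, activation="relu") for _ in range(8)]
    + [tf.keras.layers.Dense(10)]
)

# Hypothetical usage, for illustration only:
#
#   model = recompute_sequential(num_checkpoints=4)(model)
#
# where `num_checkpoints` would control the memory/compute trade-off:
# fewer stored checkpoints mean lower peak memory but more
# recomputation during the backward pass.
```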
Does it make sense to port the PR to TF core?
Will this change the current API? How?
The proposed implementation can potentially be shoehorned into the existing API for recompute_grad if desirable.
Who will benefit from this feature?
Anyone who wants to train models in resource constrained environments.
Any other info.