Gradient checkpointing for TF keras models #38766

Open
pidajay opened this issue Apr 21, 2020 · 4 comments
Labels: stat:awaiting tensorflower (Status - Awaiting response from tensorflower), type:feature (Feature requests)

Comments

@pidajay
pidajay commented Apr 21, 2020

System information

  • TensorFlow version (you are using): 2.1
  • Are you willing to contribute it (Yes/No): Yes

Describe the feature and the current behavior/state.
I have implemented a version of gradient checkpointing for TF Keras Sequential models (with future plans to extend it to the Keras functional API and custom models). The PR can be found here - tensorflow/addons#1600. I initially envisioned it as a package in the TF addons repo, but reviewers felt that was not the right place and that it could potentially go in as a fix for the existing recompute_grad functionality in TF core.

Here are the issues with the existing implementation of recompute_grad in TF core and my solutions to them:

  1. Issue - Not usable in practice. Most people have no idea how to use it, and there are no docs or tutorials that explain how to use it.
    Solution - My PR provides a notebook tutorial to demonstrate how to use the implemented functionality.
  2. Issue - For people who did figure out how to use it, no memory savings were observed.
    Solution - My PR provides links to results with observed memory savings. Caveat - only CPU-profiled results are available so far; GPU and TPU results still need to be done.
  3. Issue - The user is apparently expected to explicitly partition the model and decorate each partition (see the sketch after this list). This is not user friendly and can make tasks such as transfer learning difficult.
    Solution - My PR expects no explicit partitioning of the model. The user just needs to add a single decorator to the model. That is it.
  4. Issue - No checkpointing functionality is implemented.
    Solution - My PR implements checkpointing functionality that lets the user balance the tradeoff between memory and compute time.
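
For context on point 3, here is a minimal sketch (illustrative only, not the code from the linked PR) of the manual partition-and-decorate style that the existing tf.recompute_grad expects today; the partition boundaries, layer sizes, and shapes below are arbitrary assumptions:

```python
# Illustrative only - the layer sizes and partition boundaries are arbitrary,
# and this is not the code from the linked PR.
import tensorflow as tf

# Build the partitions up front so their variables exist before wrapping.
block1 = tf.keras.Sequential([tf.keras.layers.Dense(512, activation="relu"),
                              tf.keras.layers.Dense(512, activation="relu")])
block2 = tf.keras.Sequential([tf.keras.layers.Dense(512, activation="relu"),
                              tf.keras.layers.Dense(10)])
block1.build((None, 784))
block2.build((None, 512))

# Each partition has to be wrapped by hand; activations inside a wrapped call
# are recomputed during the backward pass instead of being kept in memory.
ckpt_block1 = tf.recompute_grad(lambda x: block1(x))
ckpt_block2 = tf.recompute_grad(lambda x: block2(x))

x = tf.random.normal([32, 784])
labels = tf.random.uniform([32], maxval=10, dtype=tf.int32)

with tf.GradientTape() as tape:
    logits = ckpt_block2(ckpt_block1(x))
    loss = tf.reduce_mean(
        tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True))

grads = tape.gradient(loss, block1.trainable_variables + block2.trainable_variables)
```

This is the partition-and-decorate pattern that point 3 above argues is not user friendly.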

Does it make sense to port the PR to TF core?

Will this change the current api? How?
The proposed implementation can potentially be shoehorned into the existing recompute_grad API if desirable.

Who will benefit with this feature?
Anyone who wants to train models in resource constrained environments.

Any Other info.

@pidajay pidajay added the type:feature Feature requests label Apr 21, 2020
@pidajay
Author
pidajay commented Apr 21, 2020

Does it make sense to port the above-mentioned PR as a PR to TF core? Also tagging @seanpmorgan from TF addons to keep in the loop. Thanks!

@gowthamkpr gowthamkpr assigned allenlavoie and unassigned gowthamkpr Apr 23, 2020
@gowthamkpr gowthamkpr added the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Apr 23, 2020
@allenlavoie allenlavoie assigned fchollet and unassigned allenlavoie Apr 23, 2020
@jarednielsen
Contributor

This functionality would be very helpful for training memory-intensive language models!

@pidajay
Author
pidajay commented Apr 27, 2020

By the way, I am thinking of addressing this with separate PRs, as follows:

  1. Provide a fix (which I already have) and a tutorial for the existing tf.recompute_grad function. From what I understand, using this requires the user to manually partition the model, which might be suboptimal in many cases.
  2. Provide a new decorator function '@recompute_sequential' that enables gradient checkpointing on Keras Sequential models (like in the PR I have linked above). The big advantage over option 1 is that the user does not have to manually partition the model; a rough, hypothetical sketch of the idea follows this list.
  3. Provide another decorator function '@recompute_functional' that enables gradient checkpointing on Keras functional models and potentially subclassed models.
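
To make option 2 concrete, here is a very rough, hypothetical sketch of what a recompute_sequential-style helper could do. This is not the implementation from my PR; the helper name, its signature, and the per-layer granularity are placeholders:

```python
# Hypothetical sketch only - not the implementation from the linked PR.
# It wraps every layer of a built Sequential model in tf.recompute_grad so the
# user never has to partition the model by hand.
import tensorflow as tf

def recompute_sequential(model):
    """Return a forward function that runs `model` layer by layer under tf.recompute_grad."""
    wrapped = [tf.recompute_grad(lambda x, layer=layer: layer(x))
               for layer in model.layers]

    def forward(x):
        for fn in wrapped:
            x = fn(x)
        return x

    return forward

model = tf.keras.Sequential([
    tf.keras.layers.Dense(1024, activation="relu"),
    tf.keras.layers.Dense(1024, activation="relu"),
    tf.keras.layers.Dense(10),
])
model.build((None, 784))  # create the variables before wrapping the layers

forward = recompute_sequential(model)

x = tf.random.normal([64, 784])
labels = tf.random.uniform([64], maxval=10, dtype=tf.int32)

with tf.GradientTape() as tape:
    logits = forward(x)
    loss = tf.reduce_mean(
        tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True))
grads = tape.gradient(loss, model.trainable_variables)
```

A real implementation would presumably checkpoint every k layers rather than every layer, which is where the memory/compute tradeoff mentioned in point 4 of the original post comes in.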

Does this approach make sense? Thanks!

@nyngwang
nyngwang commented Aug 8, 2022

@pidajay Did you make it?
