Gradient checkpointing usage #1
Same question, differently formulated. Neither version seems to slow down training, and GPU memory usage doesn't seem to have changed. Please let us know how to apply this.
@pidajay The issue here is that get_model is just called once. You want to apply the decorator to a callable which has a decent amount of internal state that you want to avoid storing; the same point should help with @Kokonut133's issue. Let's say I have a layer which takes in an X and outputs a Y, but doesn't really have any internal activations (I believe Conv2D falls into this category). Then all applying the @checkpointable decorator will do is save X and rerun the layer during backprop, so we don't get any memory savings. If I have 100 layers though, with inputs X1, X2, ..., X100, and outputs Y1, Y2, ..., Y100, you can break it up into blocks of 10, so that only every 10th output needs to be saved. Please let me know if I can clarify this further.
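To make the block idea concrete, here's a toy sketch of block checkpointing in plain Python (no TF; every name here is invented for illustration). Each "layer" squares its input, so its backward pass genuinely needs the layer's input; we store only the input of each block during the forward pass and recompute the rest during backprop:

```python
def layer(x):            # a layer whose backward pass needs its input: y = x**2
    return x * x

def layer_grad(x, g):    # dy/dx = 2x, so we must know x during backprop
    return g * 2 * x

def forward(x, n_layers, block_size):
    """Run the chain, saving only the input of each block (the checkpoints)."""
    checkpoints = []
    for i in range(n_layers):
        if i % block_size == 0:
            checkpoints.append(x)      # store block inputs only
        x = layer(x)
    return x, checkpoints

def backward(n_layers, checkpoints, block_size, grad_out=1.0):
    """Recompute each block's activations from its checkpoint, then backprop."""
    grad = grad_out
    for b in reversed(range(len(checkpoints))):
        size = min(block_size, n_layers - b * block_size)
        # Recompute the input of every layer inside this block.
        xs, x = [], checkpoints[b]
        for _ in range(size):
            xs.append(x)
            x = layer(x)
        for x in reversed(xs):         # backprop through the block
            grad = layer_grad(x, grad)
    return grad

y, cps = forward(2.0, n_layers=4, block_size=2)
g = backward(4, cps, block_size=2)
# y == 65536.0 (x**16 at x=2); g == 16 * 2**15 == 524288.0; only 2 values stored
```

With one-layer blocks this degenerates to saving every input, i.e. no savings, which is the single-layer point above.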
@davisyoshida Thanks for the clarification. I figured something like that needed to be done. I implemented a version where the user just has to specify a single decorator (like in my sample above), and then a closure grad function does the recompute for every layer during the backward pass (no activations stored). Will try to create a TF PR next week. Let's see how it goes. I will keep you in the loop.
@pidajay I would be happy to see your implementation of this. If you could share a link, that would be great. @davisyoshida So as I understand it, I did apply it correctly in both examples presented? I would prefer using the bottom version, yet it won't lead to memory savings because Conv2D layers have few internal activations to store? Would you by chance know how to reduce memory usage with Conv2D layers? Thanks in advance.
@pidajay For graph mode, it's definitely possible to do with a one-off decorator, but in eager mode it wasn't clear to me that there's a nice way to know which outputs to save. If you discard all layer outputs during the forward pass, you won't actually end up saving any memory during the backward pass.
@Kokonut133 There are a few things to fix. First, you need to pass the variables to watch via the `_watch_vars` argument. In the examples you gave, the constructed Conv2D layer isn't saved, so there's no reference which will let us access its variables. I'd recommend something like the following (this could be done a bit more cleanly in the context of a keras Model).

First construct your layers:

```python
layer_blocks = []
for i in range(10):
    block = []
    for j in range(10):
        block.append(Conv2D(filters, kernel_size=k_size, strides=2, padding="same"))
    layer_blocks.append(block)
```

Then make a function which will execute a block:

```python
@checkpointable
def run_block(block, inputs):
    layer_output = inputs
    for layer in block:
        layer_output = layer(layer_output)
    return layer_output
```

Now you can execute your whole network:

```python
block_output = inputs
for block in layer_blocks:
    # In a keras Model, we could just use self.trainable_variables
    watch = [v for layer in block for v in layer.trainable_variables]
    block_output = run_block(block, block_output, _checkpoint=True, _watch_vars=watch)
```
@davisyoshida From what I see, it is actually possible in eager mode, just a bit more involved. You have to recompute outside the gradient tape so nothing gets saved. More details: during the backward pass you recompute the current and previous layer's outputs, then calculate the gradients and feed them as output grads to the next layer in the backward pass. Essentially you are doing the full backprop yourself. As I said, a bit more involved, but it spares the user from having to manually partition the model. Will try my best to find some time to push my code by the end of this week. Will keep @Kokonut133 in the loop as well. Thanks!
@pidajay My concern is about knowing what to save during the forward pass. For example, if I have a 100-layer network, I can save every 10th layer. If I have a 900-layer network, I can save every 30th. The problem is that in eager mode you can't tell which of those two situations you're in, so you won't know which layers should be saved and which should be dropped. You could do a forward pass saving nothing, then a second forward pass where you save outputs, but that takes three forward passes instead of the two that gradient checkpointing usually takes. On the other hand, if you mean recomputing the previous layer's output by running the network from the beginning, the runtime will be quadratic in the depth of the network instead of linear, which is likely not desirable.
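A quick back-of-the-envelope comparison of the two costs mentioned above (hypothetical depth, counting one unit of compute per layer forward step):

```python
n = 100  # network depth (hypothetical)

# Standard checkpointing: one full forward pass, plus roughly one more
# forward pass worth of recomputation spread across the backward pass.
linear_cost = 2 * n

# Recomputing each layer's input from the very beginning of the network:
# reaching layer k costs k forward steps, so the total is 1 + 2 + ... + n.
quadratic_cost = sum(range(1, n + 1))

# linear_cost == 200, quadratic_cost == 5050: a ~25x gap already at depth 100.
```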
So I tried implementing your solution and got to this:
checkpointable and _checkpoint should fit now, yet I am unable to pinpoint where to put the trainable-variables parameter. I can't refer to self inside it, as the entire function that builds the network is nested in another class. I am sorry if I am asking about common Python knowledge, but it would be quite helpful.
@Kokonut133 I think it comes down to a misunderstanding of how the decorator works. The important thing is that you want to decorate a function that will be called during the training loop. However, the discriminator layer function here is only called once, during model construction. I'll illustrate this below, using a single Conv2D layer as an example, but just to reiterate: there's no point in checkpointing a single layer, since you won't actually save any memory by doing so.

```python
for batch in my_dataset:
    output = my_decorated_callable(
        batch,
        _checkpoint=True,
        _watch_vars=my_layer.trainable_variables
    )
```

Now if you want to do this with blocks of layers, look back at my earlier comment. The important thing is that I've decorated the function that executes the block of layers, not the part of the code that creates those layers to begin with. By the way, since this just grew out of something I was using for personal projects, I'm certainly open to suggestions for improvements both to usability and to documentation.
@davisyoshida and @Kokonut133: I have created my pull request over here - tensorflow/addons#1600. @davisyoshida maybe this answers the questions you had for me. Let me know. Note: I have only implemented the recompute functionality so far; I am yet to do the checkpointing, but with this approach I don't think that should be complicated. Anyways, a huge thanks for creating this repo! It helped clarify a lot of things for me. I have added this repo as a reference in my PR.
Ok, so I believe I understand more now, yet I am uncertain exactly how to apply it. I have a tf Model which I train with a custom train function. Inside this train function, I call self.model.train_on_batch(x, y, z) (a Keras function). Now, do I have to add @checkpointable before the function's definition in the directory where train_on_batch is defined, or is there a way to locally override train_on_batch with a checkpointable version? I tried my_model.train_on_batch() = checkpointable(my_model.train_on_batch()) unsuccessfully. Thank you for your help so far.
@Kokonut133 To use the built-in Keras training options, the best thing to do would be to make a custom layer that runs a checkpointed block of conv layers. Then, you can make a Sequential model from 10 of those compound layers (for a total of 100 conv layers). You definitely don't need to edit any Keras code.
@pidajay What's the policy it uses for selecting which layer outputs to save? |
@davisyoshida the current implementation does not do checkpointing yet, i.e. save the outputs of specific layers. It just recomputes for every layer during the backward pass. I am working on the saving part, but the first version should be straightforward: the user passes in a number of checkpoints (say num_checkpoints=10 for a 100-layer network), and I space the checkpoints evenly across the network (so every 10th layer will be saved). During the backward pass I just recompute from the closest checkpoint. So when I am at layer 77, I recompute from layer 70.
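The evenly-spaced policy described here can be sketched in a few lines (function names are mine for illustration, not from the PR):

```python
def checkpointed_layers(n_layers, num_checkpoints):
    """Indices of the layers whose outputs get saved, spaced evenly."""
    spacing = n_layers // num_checkpoints
    return [i * spacing for i in range(num_checkpoints)]

def nearest_checkpoint(layer_idx, n_layers, num_checkpoints):
    """Checkpoint to recompute from when backpropagating through layer_idx."""
    spacing = n_layers // num_checkpoints
    return (layer_idx // spacing) * spacing

# For a 100-layer network with 10 checkpoints, layer 77 recomputes from 70.
```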
Ah, if it's only going to be for Keras models, that's kind of a bummer. If this is what's going to end up being the official TensorFlow solution, it would be nice to have something more fine-grained if possible. Have you looked at the PyTorch checkpoint function?
The idea is to support various flavors of Keras (sequential, functional, etc.) provided there is enough traction. I don't see the point of moving outside Keras, though, when TF as a whole is moving towards Keras.
Well, in general it's much easier to make a Keras wrapper around a non-Keras-specific feature than to take something Keras-based and use it when you need to do something lower level.
I tried to add the checkpointing to my GAN as follows:
This runs inside the training loop:
but no memory improvements at all :( Am I doing anything wrong? @davisyoshida
@left-brain Can you specify how you're benchmarking memory use? |
By the way, my IDE says: "Unused variable 'watch_args'" on this line:
Actually, for some reason I'm getting the following error: "No gradients provided for any variable" when running the code above... |
@left-brain The unused watch_args was just a mistake; I didn't delete it after an earlier refactor. As you can see, the input args are watched on this line:
As for not getting any gradients, does this only occur when you are using the checkpointing decorator? I haven't seen that happen before, including when using the decorator with Keras models. |
@davisyoshida Does this repo work in graph mode?
@nyngwang You can just use tensorflow's built in stuff for graph mode. I made this since at the time, eager mode wasn't supported. (Not sure about the situation now since I haven't been using tensorflow since I switched to JAX). |
@davisyoshida Did you mean the officially provided API? And now I'm facing a situation that needs some help from experts: I cannot tell whether those APIs (either the official one or the one from this repo) are working as intended. These are my questions (all under the assumption that TensorFlow 2.x is used, not 1.x):
@nyngwang I can't really say I'm an expert on GPU profiling, but the way I knew whether these things were or weren't working was whether I could use them to run models which should fit in memory with checkpointing but fail without it. For example, if you use 100 blocks of 100 layers (with shared weights), you should run out of GPU memory if they're large enough, but checkpointing each block of 100 layers should cut the activation memory enough to train such a model in a relatively small amount of memory.
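A rough sanity check of that benchmark setup (hypothetical numbers, counting live activations rather than bytes):

```python
n_blocks, layers_per_block = 100, 100

# Without checkpointing, every layer's output stays alive until the
# backward pass, so all 10,000 activations coexist in memory.
activations_without = n_blocks * layers_per_block

# With per-block checkpointing, only each block's input is kept, plus the
# activations of the single block being recomputed at any one time.
activations_with = n_blocks + layers_per_block

# 10_000 vs 200 live activations: a 50x reduction in peak activation memory.
```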
I did. Unfortunately, it seems that the two cases, i.e. with and without checkpointing, cost the same GPU memory in my case. (I might have applied the related APIs incorrectly, so I also submitted an issue to the official repo.)
Hi @davisyoshida. I was looking into implementing my own version of gradient checkpointing in TF when I stumbled upon your repo. I tried to test your implementation but I was running into out of memory errors. Just wondering if I was using it as intended. Here is the code snippet. This is with the TF 2.2 nightly build.