Pruning: Keras subclassed model increased support #155

Open
alanchiao opened this issue Nov 11, 2019 · 7 comments
Labels: contributions welcome (see CONTRIBUTING.md for details), feature request, technique:pruning (regarding tfmot.sparsity.keras APIs and docs)


@alanchiao commented Nov 11, 2019

Currently, the pruning API throws an error when a subclassed model is passed to it. Users can work around this by going into the subclassed model's definition and applying pruning to the individual Sequential/Functional models and tf.keras layers inside it.

Better support is important for various cases (e.g. Object Detection, BERT examples) and issues such as this one.

We can provide better support for pruning an entire subclassed model.

  • Pruning only some layers of the model would still require going into the model definition itself, though you could now prune a whole subclassed model nested inside another subclassed model.
  • This would only prune variables that live inside a tf.keras.Layer (either a built-in layer or a custom layer that implements the PrunableLayer interface).

Implementation-wise, we can iterate through the layers of a subclassed model (and any nested models) and apply pruning to all of them. Replacing a layer in an already created model will be tricky, and we would have to do it without clone_model.

@alanchiao alanchiao self-assigned this Nov 11, 2019
@alanchiao alanchiao added the feature request label (and added, then removed, the duplicate label) Dec 4, 2019
@alanchiao alanchiao added the contributions welcome label Jan 14, 2020
@alanchiao alanchiao removed their assignment Jan 14, 2020
@alanchiao alanchiao added the technique:pruning label Feb 6, 2020
@nutsiepully (Contributor)

It seems to me we are mixing a few issues. I want to make sure I understand the problem correctly. Please correct me if I'm wrong.

The issue is about handling models recursively (a Keras Model that contains another model), not about subclassed models.

If one of the layers in the model is itself a Sequential or Functional model, we should ideally traverse into that model and prune all the layers within it.

For example, if the user provides the following:

import tensorflow as tf

model1 = tf.keras.Sequential([tf.keras.layers.Dense(3)])
model2 = tf.keras.Sequential([model1, tf.keras.layers.Dense(2)])

we should recursively traverse the model and prune both Dense layers. I agree with that.
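A minimal sketch of that recursive traversal, assuming TensorFlow 2.x: the hypothetical helper collect_leaf_layers walks model.layers and descends into any nested tf.keras.Model. It only finds the layers; actually swapping each one for a pruning wrapper inside an already built model is a separate step, which the original issue calls tricky.

```python
import tensorflow as tf


def collect_leaf_layers(model):
    """Recursively collect the non-model layers of a (possibly nested) Keras model.

    Hypothetical helper for illustration: it relies on `model.layers`, so it only
    sees layers that Keras tracks (including layers assigned as attributes of a
    subclassed model).
    """
    leaves = []
    for layer in model.layers:
        if isinstance(layer, tf.keras.Model):
            # Nested Sequential/Functional (or subclassed) model: descend into it.
            leaves.extend(collect_leaf_layers(layer))
        else:
            leaves.append(layer)
    return leaves


# The nested example from this comment.
model1 = tf.keras.Sequential([tf.keras.layers.Dense(3)])
model2 = tf.keras.Sequential([model1, tf.keras.layers.Dense(2)])

leaves = collect_leaf_layers(model2)
print([layer.units for layer in leaves])  # [3, 2]
```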

I am not quite sure how subclassed models play a role here. A user may combine Keras Sequential models, Functional models and layers within a subclassed model. In that case, the user has to explicitly prune whatever they want.

Since we can't reliably clone a subclassed model, it isn't possible for us to apply pruning to the whole model.

@Hackerman28

Hi @nutsiepully. I am currently having problems pruning nested Keras models. My model is based on matterport's mask-rcnn repo. It contains an inner model that is called multiple times within the outer model, with shared weights.
The inner model is defined here.
When I use the prune_low_magnitude API for pruning, it gives me the following error:
ValueError: Please initialize Prune with a supported layer. Layers should either be a PrunableLayer instance, or should be supported by the PruneRegistry. You passed: <class 'tensorflow.python.keras.engine.functional.Functional'>
So I defined a clone_function that returns prune_low_magnitude(inner_model):
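import tensorflow
import tensorflow_model_optimization as tfmot

def clone_func(layer):
    # Leave the loss/metric helper layers unwrapped.
    if isinstance(layer, tensorflow.python.keras.engine.base_layer.AddLoss):
        return layer
    if isinstance(layer, tensorflow.python.keras.engine.base_layer.AddMetric):
        return layer
    # Inner Functional model: wrap it, hoping its layers get pruned.
    if isinstance(layer, tensorflow.python.keras.engine.functional.Functional):
        return tfmot.sparsity.keras.prune_low_magnitude(layer, **pruning_params)
    # Every other layer (pruning_params is defined elsewhere in the script).
    return tfmot.sparsity.keras.prune_low_magnitude(layer, **pruning_params)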
In the code above, when the inner functional model is passed as layer, prune_low_magnitude is called on it to prune the inner model, or at least that's what I hoped for. Instead, I get the following error:
Traceback (most recent call last):
File "nucoco.py", line 510, in
augmentation=augmentation
File "/backup/Radar-RGB-Attentive-Multimodal-Object-Detection/Radar_RGB_Camera_Object_Detection/mrcnn/model.py", line 2972, in prune_train
self.keras_model = tf.keras.models.clone_model(self.keras_model, clone_function=clone_func)
File "/home/mcw/miniconda3/envs/lib/python3.6/site-packages/tensorflow/python/keras/models.py", line 429, in clone_model
model, input_tensors=input_tensors, layer_fn=clone_function)
File "/home/mcw/miniconda3/envs/lib/python3.6/site-packages/tensorflow/python/keras/models.py", line 201, in _clone_functional_model
created_layers=created_layers))
File "/home/mcw/miniconda3/envs/lib/python3.6/site-packages/tensorflow/python/keras/engine/functional.py", line 1214, in reconstruct_from_config
process_node(layer, node_data)
File "/home/mcw/miniconda3/envs/lib/python3.6/site-packages/tensorflow/python/keras/engine/functional.py", line 1162, in process_node
output_tensors = layer(input_tensors, **kwargs)
UnboundLocalError: local variable 'kwargs' referenced before assignment
Can you help me with this error?

@teijeong (Contributor)

Hi @liyunlu0618, can you check the current status?

@liyunlu0618 (Contributor)

We recently added support for pruning nested models; see this PR.

For subclassed models, since Keras doesn't support cloning them, we still don't have a model-level API. You can still reconstruct the model and wrap the layers you want to prune with the pruning API.

@didadida-r

It seems that pruning subclassed models is still not supported in 0.7.1.

@gnhearx commented Mar 17, 2022

Good day everyone :)

I would also like to chip in to this discussion as I too have been struggling to get Nested Model Pruning to work correctly.

I noticed that support has been added, but for some reason my implementation does not seem to work.
Expected behaviour:

  • Create a model that is a composition of nested models
  • Define the pruning scheduling
  • Clone the model layer by layer and set the pruning wrappers for each.
  • Fine-tune the model so that pruning can take place.
  • Remove the wrappers and re-evaluate how many weights are now zero (the sparsity).
  • Expect the pruned model to have higher sparsity than the original model had before pruning.

Observed behaviour:

  • The pruning wrappers are successfully applied to all layers in the nested model, as expected.
  • However, after pruning and removing the wrappers, the sparsity of the pruned model did not change at all.

It seems to me like the pruning is not even taking place. However, if I instantiate a model without nested models and run the exact same logic pipeline, then pruning acts exactly as expected and all layers are pruned successfully.

Am I perhaps missing a step that is not mentioned in the documentation for tensorflow/keras model pruning? Any help would be greatly appreciated.

Side notes:

  • The nested model is a pretrained VGG16 from Keras with the top removed.
  • No subclassed layers are being pruned in my model; it consists entirely of Functional models from tensorflow/keras.

@LIU-FAYANG

@alanchiao Hi, I'm currently working on pruning subclassed models, and the model does not seem to converge with pruning. I tried applying the prune_low_magnitude API directly to the layers to be pruned within the subclasses, with the same pruning schedule for each layer.
#965 mentions that the pruning callback might cause this issue when this workaround is used; could you share your opinion on this? Are there any other steps I need to take to use the workaround you mentioned to prune a subclassed model?
I think I have missed some steps needed to make this workaround work. Thanks for your kind help!
