Add TF port of BLIP #22090

Merged
57 commits merged on Apr 4, 2023
Changes from 1 commit
Commits (57)
98b2afd
Initial commit
Rocketknight1 Mar 10, 2023
e557c34
more stash commit
Rocketknight1 Mar 14, 2023
87767b0
Yet another stash commit
Rocketknight1 Mar 20, 2023
d86ec34
yet more stash commit
Rocketknight1 Mar 21, 2023
35deb28
Mostly working except for docs / repo consistency
Rocketknight1 Mar 24, 2023
0a720e4
Stop importing model list from torch file
Rocketknight1 Mar 24, 2023
490fc63
Add TF BLIP models to docs
Rocketknight1 Mar 24, 2023
6dc06bb
Add auto classes
Rocketknight1 Mar 24, 2023
9fd4b76
Move get_text_features and get_image_features
Rocketknight1 Mar 24, 2023
07f99eb
Update src/transformers/models/blip/modeling_tf_blip.py
Rocketknight1 Mar 27, 2023
8cfc37d
Update src/transformers/models/blip/modeling_tf_blip.py
Rocketknight1 Mar 27, 2023
1c47a2f
Update src/transformers/models/blip/modeling_tf_blip.py
Rocketknight1 Mar 27, 2023
70cfe55
Update src/transformers/models/blip/modeling_tf_blip_text.py
Rocketknight1 Mar 27, 2023
2024f5e
Update src/transformers/models/blip/modeling_tf_blip.py
Rocketknight1 Mar 27, 2023
cc1694d
Update src/transformers/models/blip/modeling_tf_blip.py
Rocketknight1 Mar 27, 2023
f31e96b
Update src/transformers/models/blip/modeling_tf_blip.py
Rocketknight1 Mar 27, 2023
e12e305
Update src/transformers/models/blip/modeling_tf_blip.py
Rocketknight1 Mar 27, 2023
2d622f6
Update src/transformers/models/blip/modeling_tf_blip.py
Rocketknight1 Mar 27, 2023
ad2c87c
Update tests/models/blip/test_modeling_tf_blip.py
Rocketknight1 Mar 27, 2023
6b781df
Update tests/models/blip/test_modeling_tf_blip.py
Rocketknight1 Mar 27, 2023
dab565b
Update src/transformers/models/blip/modeling_tf_blip.py
Rocketknight1 Mar 27, 2023
ee823fc
Update src/transformers/models/blip/modeling_tf_blip.py
Rocketknight1 Mar 27, 2023
d6c5869
Update tests/models/blip/test_modeling_tf_blip_text.py
Rocketknight1 Mar 27, 2023
cf307fa
Update src/transformers/models/blip/modeling_tf_blip_text.py
Rocketknight1 Mar 27, 2023
0289c28
Update src/transformers/models/blip/modeling_tf_blip.py
Rocketknight1 Mar 27, 2023
c4a4b62
Use channels_last convolutions in TF (better performance + compatibil…
Rocketknight1 Mar 27, 2023
3a082f8
Remove _shape function
Rocketknight1 Mar 27, 2023
8e73e08
Move multi-line statement to one line in PT + TF
Rocketknight1 Mar 27, 2023
7d0f73b
Specify tf.keras.layers instead of importing from it
Rocketknight1 Mar 27, 2023
4ec371b
Remove test_gradient_checkpointing and empty test_training methods
Rocketknight1 Mar 27, 2023
561d2f8
move some multi-line statements to one line
Rocketknight1 Mar 27, 2023
076948b
Update docstring for generate
Rocketknight1 Mar 27, 2023
429c25e
Remove pruned heads set
Rocketknight1 Mar 27, 2023
3086257
Remove self.seq_len_dim
Rocketknight1 Mar 27, 2023
adb0330
Fixed issues with loss computation, should resolve some tests. Also e…
Rocketknight1 Mar 29, 2023
fba2385
ensure original model follows config in more cases
Rocketknight1 Mar 30, 2023
f6c328e
Skip the same cross-attention tests in the PT tests - didn't realize …
Rocketknight1 Mar 30, 2023
4d71a05
Add training args throughout the models and layers
Rocketknight1 Mar 30, 2023
7239db5
make fixup
Rocketknight1 Mar 30, 2023
09592b2
Fix docstring for inputs_embeds
Rocketknight1 Mar 30, 2023
d4a6fa6
Add docstring for is_decoder
Rocketknight1 Mar 30, 2023
60f078c
Add docstrings to text models
Rocketknight1 Mar 30, 2023
e6a7851
Remove redundant computation
Rocketknight1 Mar 30, 2023
f3062b1
Add unpack_inputs / keras_serializable
Rocketknight1 Mar 30, 2023
77e365e
Add modeling_tf_blip to doctests
Rocketknight1 Mar 30, 2023
6fff45c
Add config classes for keras serialization
Rocketknight1 Mar 30, 2023
34463ea
Changes to allow model porting with pt-to-tf
Rocketknight1 Mar 31, 2023
60b7fb7
Quick fix to decoder head and test tweaks
Rocketknight1 Apr 3, 2023
2a7f52d
Revert an issue with masking the embeddings outputs
Rocketknight1 Apr 3, 2023
d962ac6
Allow missing keys in some equivalence tests (for unused layers)
Rocketknight1 Apr 3, 2023
0a43f85
Add tf-pt equivalence tests back in
Rocketknight1 Apr 3, 2023
09095d1
Update src/transformers/models/blip/modeling_tf_blip.py
Rocketknight1 Apr 3, 2023
dd88c83
Update src/transformers/models/blip/modeling_tf_blip_text.py
Rocketknight1 Apr 3, 2023
d0fd3d4
Update src/transformers/models/blip/modeling_tf_blip_text.py
Rocketknight1 Apr 3, 2023
9efd53c
make fixup
Rocketknight1 Apr 3, 2023
afd5a9c
Refactor invert_attention_mask out into tf_utils
Rocketknight1 Apr 3, 2023
41fe5e1
Re-enable cross-tests on the PT side too
Rocketknight1 Apr 4, 2023
Allow missing keys in some equivalence tests (for unused layers)
Rocketknight1 committed Apr 3, 2023
commit d962ac6bd0ca164b5efc8320005c6a081cecfb97
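For reference, the `allow_missing_keys` flag threaded through the two diffs below is an argument of the PT↔TF weight-conversion helpers that the equivalence test calls. A minimal sketch of the relaxed loading outside the test harness, with a tiny illustrative config (the size values and the choice of the BLIP text model are just for demonstration; this snippet is not part of the commit):

# Minimal sketch: cross-load weights between the PyTorch and TF BLIP text models
# while tolerating parameters that exist in only one framework.
# The tiny config values are illustrative.
import transformers
from transformers import BlipTextConfig, BlipTextModel, TFBlipTextModel

config = BlipTextConfig(
    vocab_size=100, hidden_size=32, num_hidden_layers=2, num_attention_heads=2, intermediate_size=64
)
pt_model = BlipTextModel(config)
tf_model = TFBlipTextModel(config)

# Copy PyTorch weights into the TF model; allow_missing_keys=True skips weights
# that have no counterpart on the other side (e.g. unused layers) instead of raising.
tf_model = transformers.load_pytorch_model_in_tf2_model(
    tf_model, pt_model, tf_inputs=tf_model.dummy_inputs, allow_missing_keys=True
)

# And the opposite direction: TF weights back into the PyTorch model.
pt_model = transformers.load_tf2_model_in_pytorch_model(
    pt_model, tf_model, allow_missing_keys=True
)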
65 changes: 2 additions & 63 deletions tests/models/blip/test_modeling_tf_blip_text.py
@@ -166,66 +166,5 @@ def test_model_from_pretrained(self):
             model = TFBlipTextModel.from_pretrained(model_name, from_pt=True)
             self.assertIsNotNone(model)
 
-    # @unittest.skip(reason="This test class covers encoder-decoder models that the base test does not work with.")
-    def test_pt_tf_model_equivalence(self):
-        import transformers
-        import inspect
-        import tempfile
-        import os
-        import torch
-
-        for model_class in self.all_model_classes:
-            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
-            # Output all for aggressive testing
-            config.output_hidden_states = True
-            config.output_attentions = self.has_attentions
-
-            # Make sure no sequence has all zeros as attention mask, otherwise some tests fail due to the inconsistency
-            # of the usage `1e-4`, `1e-9`, `1e-30`, `-inf`.
-            self._make_attention_mask_non_null(inputs_dict)
-
-            pt_model_class_name = model_class.__name__[2:]  # Skip the "TF" at the beginning
-            pt_model_class = getattr(transformers, pt_model_class_name)
-
-            tf_model = model_class(config)
-            pt_model = pt_model_class(config)
-
-            tf_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
-            tf_inputs_dict_with_labels = self._prepare_for_class(
-                inputs_dict,
-                model_class,
-                # Not all models accept "labels" in the forward pass (yet :) )
-                return_labels=True if "labels" in inspect.signature(model_class.call).parameters.keys() else False,
-            )
-
-            # For some models (e.g. base models), there is no label returned.
-            # Set the input dict to `None` to avoid check outputs twice for the same input dicts.
-            if not set(tf_inputs_dict_with_labels.keys()).symmetric_difference(tf_inputs_dict.keys()):
-                tf_inputs_dict_with_labels = None
-            # Check we can load pt model in tf and vice-versa with model => model functions
-            breakpoint()
-            tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=tf_inputs_dict)
-            pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model, allow_missing_keys=True)
-
-            # Original test: check without `labels`
-            self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)
-            # check with `labels`
-            if tf_inputs_dict_with_labels:
-                self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict_with_labels)
-
-            # Check we can load pt model in tf and vice-versa with checkpoint => model functions
-            with tempfile.TemporaryDirectory() as tmpdirname:
-                pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin")
-                torch.save(pt_model.state_dict(), pt_checkpoint_path)
-                tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(tf_model, pt_checkpoint_path)
-
-                tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5")
-                tf_model.save_weights(tf_checkpoint_path)
-                pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path)
-
-                # Original test: check without `labels`
-                self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)
-                # check with `labels`
-                if tf_inputs_dict_with_labels:
-                    self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict_with_labels)
+    def test_pt_tf_model_equivalence(self, allow_missing_keys=True):
+        super().test_pt_tf_model_equivalence(allow_missing_keys=allow_missing_keys)
10 changes: 5 additions & 5 deletions tests/test_modeling_tf_common.py
@@ -668,7 +668,7 @@ def check_pt_tf_models(self, tf_model, pt_model, tf_inputs_dict):
         self.check_pt_tf_outputs(tf_outputs, pt_outputs, type(tf_model))
 
     @is_pt_tf_cross_test
-    def test_pt_tf_model_equivalence(self):
+    def test_pt_tf_model_equivalence(self, allow_missing_keys=False):
         import transformers
 
         for model_class in self.all_model_classes:
@@ -703,8 +703,8 @@ def test_pt_tf_model_equivalence(self):
                 tf_inputs_dict_with_labels = None
 
             # Check we can load pt model in tf and vice-versa with model => model functions
-            tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=tf_inputs_dict)
-            pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model)
+            tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=tf_inputs_dict, allow_missing_keys=allow_missing_keys)
+            pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model, allow_missing_keys=allow_missing_keys)
 
             # Original test: check without `labels`
             self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)
@@ -716,11 +716,11 @@ def test_pt_tf_model_equivalence(self):
             with tempfile.TemporaryDirectory() as tmpdirname:
                 pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin")
                 torch.save(pt_model.state_dict(), pt_checkpoint_path)
-                tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(tf_model, pt_checkpoint_path)
+                tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(tf_model, pt_checkpoint_path, allow_missing_keys=allow_missing_keys)
 
                 tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5")
                 tf_model.save_weights(tf_checkpoint_path)
-                pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path)
+                pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path, allow_missing_keys=allow_missing_keys)
 
                 # Original test: check without `labels`
                 self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)
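With the common test parameterized this way, a model-specific TF test class can opt into the relaxed behaviour with a two-line override instead of copying the whole method, as the BLIP text test above now does. A sketch of that pattern for some other hypothetical TF port (the class name and empty model list are placeholders; TFModelTesterMixin lives in the repo's tests/ tree, so a file like this belongs under tests/models/<model>/):

# Sketch of the opt-in pattern, assuming a test file under tests/models/<model>/.
# TFSomeNewModelTest is a placeholder name.
import unittest

from ...test_modeling_tf_common import TFModelTesterMixin


class TFSomeNewModelTest(TFModelTesterMixin, unittest.TestCase):
    all_model_classes = ()  # would list the TF model classes under test

    def test_pt_tf_model_equivalence(self, allow_missing_keys=True):
        # Delegate to the shared implementation, tolerating weights that exist
        # in only one framework (e.g. layers the TF port never builds or uses).
        super().test_pt_tf_model_equivalence(allow_missing_keys=allow_missing_keys)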