
Add Mega: Moving Average Equipped Gated Attention #21766

Merged
66 commits merged into huggingface:main on Mar 24, 2023

Conversation

mnaylor5
Contributor

What does this PR do?

Fixes #19982

This pull request adds Mega: Moving Average Equipped Gated Attention, which is the current leader of the LRA benchmark. The implementation is adapted from the original fairseq-based repo, and I used an MLM checkpoint I created with the original implementation on the wikitext-103 dataset. There is no proposed Mega tokenizer, so I used the RoBERTa tokenizer, which is also what I used for the wikitext checkpoint. The proposed implementation works in encoder and decoder settings, and all relevant tests are passing.
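For reference, using the converted checkpoint should look roughly like the sketch below. The checkpoint ID is a hypothetical placeholder for whatever the conversion script pushes to the Hub; the tokenizer is the stock RoBERTa one, since Mega ships no tokenizer of its own.

```python
from transformers import AutoTokenizer, MegaForMaskedLM

# Mega has no tokenizer of its own, so the stock RoBERTa tokenizer is used
tokenizer = AutoTokenizer.from_pretrained("roberta-base")

# hypothetical checkpoint ID standing in for the converted wikitext-103 MLM weights
model = MegaForMaskedLM.from_pretrained("<your-username>/mega-wikitext-103")

inputs = tokenizer("The capital of France is <mask>.", return_tensors="pt")
logits = model(**inputs).logits

# decode the highest-scoring token at the mask position
mask_index = (inputs["input_ids"][0] == tokenizer.mask_token_id).nonzero(as_tuple=True)[0]
print(tokenizer.decode(logits[0, mask_index].argmax(-1)))
```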

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@ArthurZucker and @younesbelkada for text models; tagging @NielsRogge for visibility as he responded to the original issue.

@mnaylor5
Contributor Author

Alright @ArthurZucker, I think that's everything except the threads with ongoing discussion. I'm super happy with how this is shaping up! In the latest batch of commits:

  • Renamed classes, variables, and params based on comments (mainly in EMA and MovingAverageGatedAttention class)
  • Rearranged positional bias, normalization functions, activation functions, dropout classes
  • Added the copied from comments where requested
  • Added token type ID buffer
  • Added tests for generation and sequence classification
  • Moved FFT convolution into a reusable method with additional documentation (see the sketch after this comment)
  • Addressed merge conflicts from LLaMA 🦙

Thanks for the feedback and I'll wait on any more changes until you get a chance to review the updates and resolve the open discussions. Excited to get up and running with MEGA in transformers 🚀 🤗
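For readers unfamiliar with the FFT-convolution trick mentioned above: Mega's damped EMA can be applied as a per-channel convolution whose kernel spans the whole sequence, and long convolutions like that are cheaper to evaluate in the frequency domain. A generic sketch of the idea, not the exact Mega method signature:

```python
import torch

def fft_convolution(inputs: torch.Tensor, kernel: torch.Tensor, length: int) -> torch.Tensor:
    """1D convolution computed via FFT.

    inputs: (batch, channels, length) sequence
    kernel: (channels, length) per-channel kernel (e.g. EMA decay weights)
    Padding to 2 * length avoids wrap-around from the circular convolution.
    """
    inputs_fft = torch.fft.rfft(inputs.float(), n=2 * length)
    kernel_fft = torch.fft.rfft(kernel.float(), n=2 * length)
    convolved = torch.fft.irfft(inputs_fft * kernel_fft, n=2 * length)
    return convolved[..., :length]

x = torch.randn(2, 8, 64)    # (batch, channels, seq_len)
k = torch.rand(8, 64) * 0.1  # toy per-channel kernel
print(fft_convolution(x, k, length=64).shape)  # torch.Size([2, 8, 64])
```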

Collaborator
@ArthurZucker left a comment

Wow! A lot of work, and I think we are almost there! I left a few nits here and there again, but this should be ready soon. 🚀

@@ -121,6 +121,22 @@ def forward(self, x: Tensor) -> Tensor:
return torch.clip(gelu(x), self.min, self.max)


class AccurateGELUActivation(nn.Module):
Collaborator

Nice 😉
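For context, the "accurate GELU" added here is presumably the tanh-based approximation used in the original Mega code; a self-contained sketch of that formula:

```python
import math

import torch
from torch import nn

class AccurateGELU(nn.Module):
    """tanh approximation of GELU: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))."""

    def __init__(self):
        super().__init__()
        self.precomputed_constant = math.sqrt(2 / math.pi)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return 0.5 * x * (1 + torch.tanh(self.precomputed_constant * (x + 0.044715 * torch.pow(x, 3))))

print(AccurateGELU()(torch.tensor([-1.0, 0.0, 1.0])))  # very close to exact GELU values
```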

Comment on lines 110 to 117
ema_delta_alpha_range (`float`, *optional*, defaults to 0.2):
The standard deviation for initializing the delta (damping factor) and alpha (decay factor) parameters in
MultiDimensionDampedEMA.
ema_beta_range (`float`, *optional*, defaults to 0.02):
The standard deviation for initializing the beta parameter (expansion matrix) in MultiDimensionDampedEMA.
ema_gamma_omega_range (`float`, *optional*, defaults to 1.0):
The standard deviation for initializing the gamma (projection matrix) and omega (residual weight)
parameters in MultiDimensionEMA.
Collaborator

nice 😉



# utility for causal LM masking in the format that Mega expects
def generate_causal_mask(seq_len):
Collaborator

This looks a lot like create_extended_attention_mask_for_decoder, which you can use since it is available in all the PreTrainedModels! (The biggest difference seems to be that this one is not batched.)

Contributor Author

Oh, great point! I didn't know about that method - super helpful 😄
Mega's attention methods expect a non-batched causal mask, so I can just index the one produced by the built-in method!

Contributor Author

This is addressed in the upcoming commit - I also realized that my old method was not handling the device. Since I haven't been able to test locally with a GPU, that had not triggered a test failure.
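To make the masking point concrete: the shape Mega expects can be produced directly with torch.tril on the right device, and (as discussed above) the batched mask from the built-in helper can simply be indexed down to it. A minimal sketch, not the exact code in the commit:

```python
import torch

def generate_causal_mask(seq_len: int, device: torch.device = None) -> torch.Tensor:
    """Non-batched (seq_len, seq_len) lower-triangular causal mask, created on `device`."""
    return torch.tril(torch.ones(seq_len, seq_len, device=device))

print(generate_causal_mask(4))
# tensor([[1., 0., 0., 0.],
#         [1., 1., 0., 0.],
#         [1., 1., 1., 0.],
#         [1., 1., 1., 1.]])

# A batched (batch, 1, seq, seq) extended mask could be reduced the same way,
# e.g. extended_mask[0, 0] when every sequence in the batch shares the causal pattern.
```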



@require_torch
class MegaModelIntegrationTest(TestCasePlus):
Collaborator

Thanks for adding that, looks good to me
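For anyone skimming, the usual shape of such an integration test in transformers is roughly the following; the checkpoint ID is a placeholder, and the real test would compare output slices against values computed with the original fairseq implementation rather than only checking shapes:

```python
import torch
from transformers import MegaModel
from transformers.testing_utils import TestCasePlus, require_torch


@require_torch
class MegaModelIntegrationTest(TestCasePlus):
    def test_inference_no_head(self):
        # hypothetical checkpoint ID standing in for the converted weights
        model = MegaModel.from_pretrained("<your-username>/mega-wikitext-103")
        input_ids = torch.tensor([[0, 31414, 232, 2]])  # a short RoBERTa-tokenized sequence

        with torch.no_grad():
            output = model(input_ids).last_hidden_state

        # shape check only; the actual test would also assert torch.allclose against
        # an expected slice produced by the original implementation
        self.assertEqual(output.shape, (1, 4, model.config.hidden_size))
```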

@mnaylor5
Contributor Author

@ArthurZucker as an update, it looks like the fix for left-padding is going to be a more significant effort to implement -- the relative bias is applied in the attention function, and it expects all of the inputs to be left-to-right starting at position 0. We can probably refactor to accept the position IDs like they did for CodeGen, but we'll also need to change how the bias is added since it is currently using a single (seq_len, seq_len) tensor for the entire batch. Refactoring that might be the heavier lift, but I'm still exploring.

I'll dig more into this tomorrow, but for the meantime, I've pushed updates that address the rest of your comments! If you have any other suggestions on the fix for relative positions, I'd love to hear them! 😄
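To illustrate why left-padding is awkward here: a simple relative bias is one learned vector indexed by the offset i - j, materialized as a single (seq_len, seq_len) tensor and added to the attention scores of every sequence in the batch, which implicitly assumes all sequences start at position 0 (i.e. right padding). A toy sketch of that pattern, not the actual Mega implementation:

```python
import torch
from torch import nn

class SimpleRelativeBias(nn.Module):
    """One learned bias per relative offset i - j, shared across the whole batch."""

    def __init__(self, max_positions: int):
        super().__init__()
        self.rel_pos_bias = nn.Parameter(torch.zeros(2 * max_positions - 1))

    def forward(self, seq_len: int) -> torch.Tensor:
        positions = torch.arange(seq_len)
        # offsets i - j lie in [-(seq_len - 1), seq_len - 1]; shift them into valid indices
        offsets = positions[:, None] - positions[None, :] + (seq_len - 1)
        return self.rel_pos_bias[offsets]  # (seq_len, seq_len), identical for every example

print(SimpleRelativeBias(max_positions=128)(seq_len=16).shape)  # torch.Size([16, 16])
# With left-padding, the same (i, j) cell maps to different true offsets per example,
# so a per-example bias (or position_ids handling, as in CodeGen) would be required.
```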

@ArthurZucker
Collaborator

Sure! Also, it's not that important to have left padding in this PR; it can be added in another one!

@mnaylor5
Contributor Author

Thanks @ArthurZucker! After digging into it, I do think it will require a pretty significant refactor to support left-padding in this PR. If you're comfortable with it, I agree that it could make sense in a new PR. I just added an entry in the MegaBlock docstring for the new causal_mask coming from the pretrained model's method, and added a missing device for the token type IDs.

Also pulled latest changes from main to hopefully prevent whatever was causing the tests for exotic models to fail. I'm really happy with how this is looking, so let me know if there's anything else needed to move forward with this PR! Appreciate your comments and guidance on everything so far! 🚀

@ArthurZucker
Collaborator

Awesome, it's alright with me to leave this to another PR. Will do my final review before pinging @sgugger for another pair of eyes!

Collaborator
@ArthurZucker left a comment

Well done! 🔥
I left a few comments here and there about naming conventions and docstrings, but this is very detailed, and I love that you took the time to address all of my comments! Thanks for bearing with me 😉

return output


class MovingAverageGatedAttention(nn.Module):
Collaborator

Suggested change
class MovingAverageGatedAttention(nn.Module):
class MegaMovingAverageGatedAttention(nn.Module):

Mega should prefix all the Mega classes (norms and MultiDimensionDampedEMA included). We are probably also going to rename MultiDimensionDampedEMA to MultiDimensionDampedEma! This is really a nit, but users have an easier time guessing the names if we stick to camel case everywhere.

Contributor Author

Working on this now and will add in a local commit to make sure it's updated wherever it's referenced

self.norm = MegaSequenceNorm(
self.config.normalization_type, self.config.hidden_size, affine=self.config.norm_affine
)
self.move = MultiDimensionDampedEMA(config)
Collaborator

Suggested change
self.move = MultiDimensionDampedEMA(config)
self.ema_gate = MultiDimensionDampedEMA(config)

Contributor Author

Also going in a local commit as this will need to change in the weight conversion script

Comment on lines 401 to 403
config.chunk_size = (
input_ids.size(1) * 2
) # we want the chunk size to be < sequence length, and the sequence length to be a multiple of chunk size
Collaborator

Suggested change
config.chunk_size = (
input_ids.size(1) * 2
) # we want the chunk size to be < sequence length, and the sequence length to be a multiple of chunk size
# we want the chunk size to be < sequence length, and the sequence length to be a multiple of chunk size
config.chunk_size = (input_ids.size(1) * 2)

Contributor Author

Applying this in the local commit because I'm not sure whether the parentheses are forcing the automated style checks to expand across lines


self.parent.assertEqual(result[0].shape, (self.batch_size, self.seq_length, self.hidden_size))

def check_chunking_shorter_sequence(
Collaborator

Nice!

Comment on lines 1013 to 1015
query_key, attention_gate = torch.split(
F.silu(query_key_gates), [self.config.shared_representation_size, self.config.intermediate_size], dim=-1
)
Collaborator

Let's also split this into 2 lines (the activation first, then this)

Contributor Author

Done in upcoming commit
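For reference, the two-line version suggested above would presumably read something like this (toy tensor sizes stand in for the config values):

```python
import torch
import torch.nn.functional as F

# toy sizes standing in for config.shared_representation_size and config.intermediate_size
shared_representation_size, intermediate_size = 4, 6
query_key_gates = torch.randn(2, 10, shared_representation_size + intermediate_size)

# activation on its own line first...
activated = F.silu(query_key_gates)
# ...then the split, as suggested in the review
query_key, attention_gate = torch.split(
    activated, [shared_representation_size, intermediate_size], dim=-1
)
print(query_key.shape, attention_gate.shape)  # torch.Size([2, 10, 4]) torch.Size([2, 10, 6])
```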

Collaborator
@sgugger left a comment

Thanks a lot for all your work adding this model! My main comment is that all building blocks in the modeling file should be prefixed by Mega to avoid any name conflicts with other models.

return embeddings


class SimpleRelativePositionalBias(nn.Module):
Collaborator

Should be prefixed by Mega

Contributor Author

Done in upcoming commit

return tile


class RotaryRelativePositionalBias(nn.Module):
Collaborator

Should be prefixed by Mega

Contributor Author

Done in upcoming commit

Comment on lines 262 to 268
NORM2FN = {
"layernorm": lambda embedding_dim, eps, affine: nn.LayerNorm(embedding_dim, eps, elementwise_affine=affine),
"scalenorm": lambda embedding_dim, eps, affine: ScaleNorm(dim=-1, eps=eps, affine=affine),
"rmsnorm": lambda embedding_dim, eps, affine: MegaRMSNorm(embedding_dim, eps=eps, affine=affine),
"batchnorm": lambda embedding_dim, eps, affine: nn.BatchNorm1d(embedding_dim, eps=eps, affine=affine),
"syncbatchnorm": lambda embedding_dim, eps, affine: nn.SyncBatchNorm(embedding_dim, eps=eps, affine=affine),
}
Collaborator

Using lambda functions here will make MegaSequenceNorm, and then the whole model, unpicklable, I fear. It's probably better to have five if/else branches in the MegaSequenceNorm module.

Collaborator

My bad! I was the one pushing for this! Miscalculated the pickling

Contributor Author

Ah, that makes sense! I originally had it as an if/else in the MegaSequenceNorm.__init__, so I'll just go back to that design
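A minimal sketch of that if/else design, trimmed to the normalization types that torch provides out of the box (the scalenorm and rmsnorm branches would dispatch to the Mega-specific classes defined elsewhere in the modeling file):

```python
import pickle

import torch
from torch import nn

class MegaSequenceNorm(nn.Module):
    """Selects the configured normalization with plain if/else so the module stays picklable."""

    def __init__(self, norm_type: str, embedding_dim: int, eps: float = 1e-5, affine: bool = True):
        super().__init__()
        if norm_type == "layernorm":
            self.norm = nn.LayerNorm(embedding_dim, eps, elementwise_affine=affine)
        elif norm_type == "batchnorm":
            self.norm = nn.BatchNorm1d(embedding_dim, eps=eps, affine=affine)
        elif norm_type == "syncbatchnorm":
            self.norm = nn.SyncBatchNorm(embedding_dim, eps=eps, affine=affine)
        else:
            # "scalenorm" and "rmsnorm" would instantiate the Mega-prefixed norm classes here
            raise ValueError(f"Unknown normalization type: {norm_type}")

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.norm(hidden_states)

norm = MegaSequenceNorm("layernorm", embedding_dim=32)
pickle.loads(pickle.dumps(norm))  # round-trips fine with the if/else construction
print(norm(torch.randn(2, 5, 32)).shape)  # torch.Size([2, 5, 32])
```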

ALL_LAYERNORM_LAYERS.append(MegaSequenceNorm)


class MultiDimensionDampedEMA(nn.Module):
Collaborator

Should be prefixed by Mega

Contributor Author

Done in upcoming commit

Comment on lines 764 to 766
# Normalization modules
# copied from original Mega repo without modification except variable names
class ScaleNorm(nn.Module):
Collaborator

Also needs to be prefixed by Mega

Comment on lines 40 to 42
from transformers.models.mega.modeling_mega import (
MEGA_PRETRAINED_MODEL_ARCHIVE_LIST,
)
Collaborator

Fits in one line.

Contributor Author

Done in upcoming commit

mnaylor5 and others added 5 commits March 23, 2023 10:59
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
@mnaylor5
Contributor Author

Thanks again @ArthurZucker and @sgugger! Appreciate the feedback, and it should all be addressed in the latest changes 🤗

Collaborator
@sgugger left a comment

Congrats on adding this new model to Transformers!

@sgugger merged commit 57f25f4 into huggingface:main on Mar 24, 2023
@ArthurZucker
Collaborator

Great working with you @mnaylor5 ! Congrats again on the merge 🔥

@NielsRogge
Contributor

Congrats @mnaylor5! Feel free to share on social media and we'll amplify your post.

@mnaylor5
Contributor Author

Thanks so much @ArthurZucker and @NielsRogge! I learned a ton through this process, and it's so rewarding to see my code in a library I use so much ❤️

I posted something here on LinkedIn a couple days ago - I'll tag you guys in the comments as well!
https://www.linkedin.com/posts/mitchnaylor_mega-activity-7045103140890660864-9VOU

raghavanone pushed a commit to raghavanone/transformers that referenced this pull request Apr 5, 2023
* add mega file structure and plain pytorch version of mega source code

* added config class with old naming conventions

* filled in mega documentation

* added config class and embeddings with optional token types

* updated notes

* starting the conversion process, deleted intermediate and added use_cache back to config

* renamed config attributes in modeling_mega.py

* checkpointing before refactoring incremental decoding functions

* removed stateful incremental key/values for EMA and self-attention

* refactored MovingAverageGatedAttention to remove stateful k/v history and use unified attention mask

* MovingAverageGatedAttention works with incremental decoding + past values, added sequence length enforcement

* more comments in MovingAverageGatedAttention + checkpointing before GatedCrossAttention

* bug fix in attention mask handling in MovingAverageGatedAttention

* removed incremental state from GatedCrossAttention and removed IncrementalState class

* finished gated cross attention and got MegaLayer working

* fixed causal masking in mega decoder

* fixed how padding and causal masks are passed through MegaLayer with and without k/v caching

* finished MegaModel; tested with encoder, decoder-only, and cross-attention type inputs; started work on downstream classes; removed mentions of position_ids

* added optional dense hidden layer for masked and causal LM classes

* docstring updates in MultiHeadEMA and GatedCrossAttention, removed unnecessary inputs in cross-attention

* removed before_attn_fn in Mega class and updated docstrings and comments up to there

* bug fix in MovingAverageGatedAttention masking

* working conversion of MLM checkpoint in scratchpad script -- perfect matches

* moved arg for hidden dense layer in LM head to config; discovered issue where from_pretrained is renaming gamma and beta parameters

* renamed gamma and beta parameters to avoid HF renaming when loading from checkpoint

* finished checkpoint conversion script

* cleanup old class in mega config script

* removed 'copied from' statements and passing integration tests

* added num_attention_heads=1 to config for integration compatibility, decoder tests working, generation tests failing

* fixed tuple output of megamodel

* all common tests passing after fixing issues in decoder, gradient retention, and initialization

* added mega-specific tests, ready for more documentation and style checks

* updated docstrings; checkpoint before style fixes

* style and quality checks, fixed initialization problem in float_tensor, ready for PR

* added mega to toctree

* removed unnecessary arg in megaconfig

* removed unused arg and fixed code samples with leftover roberta models

* Apply suggestions from code review

Applied all suggestions except the one renaming a class, as I'll need to update that throughout

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* fixed issue where .view breaks batch dimension, conversion script fixed with absolute imports, updated readme with Mega->MEGA

* removed asserts in Mega code, renamed sequencenorm, gatedcrossattention, and NFFN, replaced get_activation_fn with ACTFN, and added sequencenorm to layer norms

* reformatted .forward() docstrings to match style and removed unused mask input in cross-attention

* removed all reset_parameters() methods and rolled into MegaPreTrainedModel._init_weights()

* renamed all single-letter variables and improved readability in tensor size comments, Mega->MEGA in 2 documentation files

* variable names in NFFN

* manual Mega->MEGA changes in docs

* Mega->MEGA in config auto

* style and quality fixes

* Apply suggestions from code review

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* renamed parameters and variables with confusing names, added copied from statements, moved fft conv to its own method, other cleanup from PR comments

* commit before dealing with merge conflicts

* made new attention activation functions available in ACT2FN and added generation test from OPT

* style and quality in activations and tests

* documentation fixes, renaming variables in dropout and rotary positions, used built-in causal masking, encoders->layers in MegaModel, moved comments into docstrings

* style and quality fixes after latest updates, before rotary position ids

* causal mask in MegaBlock docstring + added missing device passing

* Apply suggestions from code review

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update README.md

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* added Mega prefixes where missing, reverted MegaSequenceNorm to if-else, other module renaming requested in PR

* style and quality fixes + readme updates pointing to main

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
novice03 pushed a commit to novice03/transformers that referenced this pull request Jun 23, 2023 (same commit message as above)
Development

Successfully merging this pull request may close these issues.

Add MEGA
6 participants