Llama: SDPA FA2 path + static cache fix #30437
In modeling_llama.py:

@@ -616,10 +616,13 @@ def forward(
          cache_position: Optional[torch.LongTensor] = None,
      ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
          if output_attentions:
-            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is
+            # implemented.
              logger.warning_once(
-                "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
-                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+                "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does "
+                "not support `output_attentions=True`. Falling back to the manual attention implementation, "
+                "but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
+                'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
              )
              return super().forward(
                  hidden_states=hidden_states,
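The warning re-wrapped above also describes the user-facing workaround: passing `attn_implementation="eager"` avoids the SDPA fallback entirely when attention weights are needed. A minimal usage sketch, not part of the diff; the checkpoint name is only an example:

```python
# Hedged sketch: opt into the manual ("eager") attention implementation so that
# output_attentions=True works without triggering the SDPA fallback warning above.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"  # example checkpoint, assumes you have access
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="eager")

inputs = tokenizer("Hello", return_tensors="pt")
outputs = model(**inputs, output_attentions=True)
print(len(outputs.attentions))  # one attention-weight tensor per layer
```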
@@ -659,15 +662,19 @@ def forward(
          if attention_mask is not None:
              causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]

-        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
-        # Reference: https://github.com/pytorch/pytorch/issues/112577.
+        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom
+        # attn_mask. Reference: https://github.com/pytorch/pytorch/issues/112577.
          if query_states.device.type == "cuda" and causal_mask is not None:
              query_states = query_states.contiguous()
              key_states = key_states.contiguous()
              value_states = value_states.contiguous()

-        # In case we are not compiling, we may set `causal_mask` to None, which is required to dispatch to SDPA's Flash Attention 2 backend, rather
-        # relying on the `is_causal` argument.
+        # In case we are not compiling, we may set `causal_mask` to None, which is required to dispatch to SDPA's Flash
+        # Attention 2 backend, rather relying on the `is_causal` argument. In that case, if using static cache, we need
+        # to drop the empty KV entries
+        # if causal_mask is None and cache_position is not None and isinstance(past_key_value, StaticCache):
+        #     key_states = key_states[:, :, : cache_position[-1] + 1, :]
+        #     value_states = value_states[:, :, : cache_position[-1] + 1, :]
          attn_output = torch.nn.functional.scaled_dot_product_attention(
              query_states,
              key_states,
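For context on why the PR cares about `causal_mask` being `None`: `scaled_dot_product_attention` can only route to its Flash Attention backend when no explicit `attn_mask` is passed and causality is expressed via `is_causal=True`. A standalone sketch of that dispatch rule, with illustrative shapes; `torch.backends.cuda.sdp_kernel` is the same backend-selection context manager the old test further down uses:

```python
import torch
import torch.nn.functional as F

# Toy stand-ins for query/key/value states: (batch, num_heads, seq_len, head_dim), fp16 on CUDA.
q = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# With attn_mask=None and is_causal=True, SDPA may pick the Flash Attention backend.
# Restricting to flash-only makes a failed dispatch visible (it raises instead of silently
# falling back to another backend).
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True)

print(out.shape)  # torch.Size([1, 8, 128, 64])
```

Passing an explicit mask tensor instead forces the math or memory-efficient backends, which is what the `_update_causal_mask` comment below says the code tries to avoid when possible.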
@@ -1073,6 +1080,7 @@ def _update_causal_mask(
          if self.config._attn_implementation == "sdpa":
              # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument,
              # in order to dispatch on Flash Attention 2.
+            breakpoint()
              if AttentionMaskConverter._ignore_causal_mask_sdpa(
                  attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens
              ):

Reviewer comment on the added breakpoint() line: remove
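The stray `breakpoint()` flagged above was presumably left over from debugging the dispatch. A less intrusive way to check which SDPA backend actually runs is to profile a call and look for flash-attention kernels in the trace; this is an illustrative sketch, not part of the PR, and kernel names vary across PyTorch versions:

```python
import torch
from torch.profiler import ProfilerActivity, profile

# Toy SDPA call standing in for the attention forward pass; needs a CUDA GPU with flash kernels.
q = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

with profile(activities=[ProfilerActivity.CUDA]) as prof:
    torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)

# If the flash backend was selected, a kernel whose name contains "flash" shows up in the table.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
```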
In tests/models/llama/test_modeling_llama.py:

@@ -20,9 +20,8 @@
  import pytest
  from parameterized import parameterized

-from transformers import LlamaConfig, StaticCache, is_torch_available, logging, set_seed
+from transformers import LlamaConfig, is_torch_available, set_seed
  from transformers.testing_utils import (
-    CaptureLogger,
      require_bitsandbytes,
      require_flash_attn,
      require_read_token,
@@ -684,17 +683,18 @@ def test_model_13b_greedy_generation(self):
      @require_torch_gpu
      @require_read_token
      def test_compile_static_cache(self):
+        # `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 and torch>2.2
+        # work as intended. See https://github.com/pytorch/pytorch/issues/121943
          NUM_TOKENS_TO_GENERATE = 40
-        EXPECTED_TEXT_COMPLETION = {
-            7: [
-                "Simply put, the theory of relativity states that 1) the speed of light is constant, 2) the speed of light is the same for all observers, and 3) the laws of physics are the same for all observers.",
-                "My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p",
-            ],
-            8: [
-                "Simply put, the theory of relativity states that 1) the speed of light is the same for all observers, and 2) the laws of physics are the same for all observers.\nThe first part of the theory of relativity",
-                "My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p",
-            ],
-        }
+        # Note on `EXPECTED_TEXT_COMPLETION`'s diff: the current value matches the original test if the original test
+        # was changed to have a cache of 53 tokens (as opposed to 4096).
+        EXPECTED_TEXT_COMPLETION = [
+            "Simply put, the theory of relativity states that 1) the speed of light is constant in all inertial "
+            "reference frames, and 2) the laws of physics are the same for all inertial reference frames.\nThe theory "
+            "of relativ",
+            "My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my "
+            "fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p",
+        ]

          prompts = [
              "Simply put, the theory of relativity states that ",

Reviewer comment on lines +686 to +687 (the torch version note): you can skipTest on torch version

Reviewer comment on the removed EXPECTED_TEXT_COMPLETION dict: note you might get different results from A10 to T4, hence this dict. this change can lead to the push-important-models test + Slow tests to fail 😢
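The skipTest suggestion above could be implemented with a version guard at the top of the test. A hedged sketch, where the class name and the exact version check are assumptions rather than code from the PR:

```python
import unittest

import torch
from packaging import version


class CompileStaticCacheTest(unittest.TestCase):  # hypothetical stand-in for the Llama integration test class
    def test_compile_static_cache(self):
        # Skip only on torch==2.2.x, the release range the diff comment says is broken
        # (https://github.com/pytorch/pytorch/issues/121943); torch==2.1.2 and torch>2.2 work.
        if version.parse(torch.__version__).release[:2] == (2, 2):
            self.skipTest("torch==2.2 breaks compiled static-cache generation")
        ...  # rest of the test as in the diff above
```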
@@ -706,38 +706,25 @@ def test_compile_static_cache(self):
          )
          inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)

-        def decode_one_tokens(model, cur_token, input_pos, cache_position):
-            logits = model(
-                cur_token, position_ids=input_pos, cache_position=cache_position, return_dict=False, use_cache=True
-            )[0]
-            new_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
-            return new_token
-
-        batch_size, seq_length = inputs["input_ids"].shape
-        with torch.no_grad():
-            model._setup_cache(StaticCache, 2, max_cache_len=4096)
-            cache_position = torch.arange(seq_length, device=torch_device)
-            generated_ids = torch.zeros(
-                batch_size, seq_length + NUM_TOKENS_TO_GENERATE + 1, dtype=torch.int, device=torch_device
-            )
-            generated_ids[:, cache_position] = inputs["input_ids"].to(torch_device).to(torch.int)
-
-            logits = model(**inputs, cache_position=cache_position, return_dict=False, use_cache=True)[0]
-            next_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
-            generated_ids[:, seq_length] = next_token[:, 0]
-
-            decode_one_tokens = torch.compile(decode_one_tokens, mode="reduce-overhead", fullgraph=True)
-            cache_position = torch.tensor([seq_length + 1], device=torch_device)
-            for _ in range(1, NUM_TOKENS_TO_GENERATE):
-                with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
-                    with CaptureLogger(logging.get_logger(__name__)) as cl:
-                        next_token = decode_one_tokens(model, next_token.clone(), None, cache_position)
-                        self.assertNotIn("skipping cudagraphs due to", cl.out)
-                    generated_ids[:, cache_position] = next_token.int()
-                cache_position += 1
-
-        text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
-        self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], text)
+        # Dynamic Cache
+        generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False)
+        dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+        self.assertEqual(EXPECTED_TEXT_COMPLETION, dynamic_text)
+
+        # Static Cache
+        generated_ids = model.generate(
+            **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
+        )
+        static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+        self.assertEqual(EXPECTED_TEXT_COMPLETION, static_text)
+
+        # Static Cache + compile
+        model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
+        generated_ids = model.generate(
+            **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
+        )
+        static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+        self.assertEqual(EXPECTED_TEXT_COMPLETION, static_compiled_text)


  @require_torch

Reviewer comment on lines +709 to +727 (the rewritten test body): Shouldn't the static cache tests be somewhere like test_modeling_common.py and test every supported model?
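Outside the test suite, the three code paths exercised by the rewritten test map onto ordinary `generate` usage. A hedged end-user sketch: the checkpoint name is only an example, generation kwargs are kept minimal, and a CUDA GPU is assumed:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"  # example checkpoint, assumes access
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
inputs = tokenizer("Simply put, the theory of relativity states that ", return_tensors="pt").to(model.device)

# 1) Default dynamic KV cache
out = model.generate(**inputs, max_new_tokens=40, do_sample=False)

# 2) Static KV cache (fixed-size tensors, the compile-friendly path this PR fixes)
out = model.generate(**inputs, max_new_tokens=40, do_sample=False, cache_implementation="static")

# 3) Static cache + compiled forward, mirroring the test above
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
out = model.generate(**inputs, max_new_tokens=40, do_sample=False, cache_implementation="static")
print(tokenizer.decode(out[0], skip_special_tokens=True))
```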
Reviewer comment (on the comment re-wrapping): 120 char limit OCD :D

Reply: It is not in `make style`, is it? (referring to transformers/pyproject.toml, line 5 at a98c417)