Cache: Static cache as a standalone object #30476

Merged: 12 commits, merged on Apr 30, 2024
Changes from 1 commit
tmp commit
gante committed Apr 29, 2024
commit 56cf73b2230c146d5257ae2771aff6ff9584e0e8
7 changes: 4 additions & 3 deletions src/transformers/cache_utils.py
@@ -6,6 +6,7 @@
from .configuration_utils import PretrainedConfig
from .utils import logging


logger = logging.get_logger(__name__)


@@ -370,11 +371,11 @@ def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len:
# Note: `mark_static_address` is used to tag the cache as a fixed data pointer, preventing cuda graph
# breaks when updating the cache.
new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
self.key_cache.append(new_layer_key_cache)
torch._dynamo.mark_static_address(new_layer_key_cache)
new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
self.value_cache.append(new_layer_value_cache)
torch._dynamo.mark_static_address(new_layer_key_cache)
torch._dynamo.mark_static_address(new_layer_value_cache)
self.key_cache.append(new_layer_key_cache)
self.value_cache.append(new_layer_value_cache)

def update(
self,
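For context on the hunk above: the three deleted lines appended each buffer immediately after allocation and called `torch._dynamo.mark_static_address` on the key cache twice, while the value cache was never marked; the new ordering marks both buffers and only then appends them. A minimal sketch of the resulting allocation pattern, using made-up `num_layers` and `cache_shape` values rather than the real StaticCache configuration:

import torch

# Placeholder values, not the actual StaticCache configuration.
num_layers = 2
cache_shape = (1, 8, 128, 64)  # (batch, num_kv_heads, max_cache_len, head_dim)
dtype, device = torch.float32, "cpu"

key_cache, value_cache = [], []
for _ in range(num_layers):
    new_layer_key_cache = torch.zeros(cache_shape, dtype=dtype, device=device)
    new_layer_value_cache = torch.zeros(cache_shape, dtype=dtype, device=device)
    # Tag each buffer as a fixed data pointer so in-place cache updates do not
    # break cuda graphs under torch.compile.
    torch._dynamo.mark_static_address(new_layer_key_cache)
    torch._dynamo.mark_static_address(new_layer_value_cache)
    key_cache.append(new_layer_key_cache)
    value_cache.append(new_layer_value_cache)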
22 changes: 13 additions & 9 deletions src/transformers/models/llama/modeling_llama.py
@@ -428,7 +428,6 @@ def forward(
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:

if isinstance(past_key_value, StaticCache):
raise ValueError(
"`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
@@ -930,7 +929,7 @@ def forward(
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache | List[torch.FloatTensor]] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
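The `past_key_values` annotation change here (and in the other `forward` signatures below) swaps the PEP 604 `Cache | List[torch.FloatTensor]` spelling for `typing.Union`: the `X | Y` syntax needs Python 3.10 at runtime, while `Union[...]` also works on the older Python versions transformers still supports. A throwaway sketch of the equivalent annotations; `forward_sketch` is not a real transformers function:

from typing import List, Optional, Union

import torch

from transformers.cache_utils import Cache


def forward_sketch(
    past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,  # OK on Python 3.8+
    # past_key_values: Optional[Cache | List[torch.FloatTensor]] = None,      # needs Python 3.10+
) -> None:
    ...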
@@ -959,7 +958,7 @@ def forward(
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)

if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs)
if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs)
past_key_values = DynamicCache.from_legacy_cache(past_key_values)

if cache_position is None:
@@ -1023,7 +1022,9 @@ def forward(
next_cache = None
if use_cache:
next_cache = (
next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, DynamicCache) else next_decoder_cache
next_decoder_cache.to_legacy_cache()
if isinstance(next_decoder_cache, DynamicCache)
else next_decoder_cache
)
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
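The cache-handling hunks above keep backward compatibility with the legacy tuple-of-tuples format: incoming tuples are wrapped with `DynamicCache.from_legacy_cache`, and `to_legacy_cache` converts back before returning. A small round-trip sketch with a made-up single-layer cache:

import torch

from transformers import DynamicCache

# Made-up single-layer legacy cache: a tuple of (key, value) pairs of shape
# (batch, num_heads, seq_len, head_dim).
legacy_past = ((torch.zeros(1, 8, 4, 64), torch.zeros(1, 8, 4, 64)),)

cache = DynamicCache.from_legacy_cache(legacy_past)  # legacy tuples -> Cache object
assert cache.get_seq_length() == 4

round_trip = cache.to_legacy_cache()  # Cache object -> legacy tuples, for BC
assert round_trip[0][0].shape == legacy_past[0][0].shape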
@@ -1089,6 +1090,10 @@ def _update_causal_mask(
# backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
# cache. In that case, the 4D attention mask attends to the newest tokens only.
if attention_mask.shape[-2] < cache_position[0] + sequence_length:
logger.warning_once(
"Passing a 4d mask shorter than the input length is deprecated and will be removed in "
"transformers v4.42.0"
)
offset = cache_position[0]
else:
offset = 0
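The warning added above covers the deprecated path where a user-supplied 4D attention mask is shorter than `cache_position[0] + sequence_length`; in that case the mask only attends to the newest tokens and `cache_position[0]` is used as the offset. A tiny numeric illustration with made-up values:

import torch

# Made-up values: one new token decoded at absolute position 10, with a 4D
# mask that only covers that single query row.
sequence_length = 1
cache_position = torch.tensor([10])
attention_mask = torch.zeros(1, 1, 1, 11)  # (batch, heads, query_len, kv_len)

if attention_mask.shape[-2] < cache_position[0] + sequence_length:
    offset = cache_position[0]  # mask is applied to the newest tokens only
else:
    offset = 0
print(offset)  # tensor(10)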
@@ -1148,7 +1153,7 @@ def forward(
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache | List[torch.FloatTensor]] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
@@ -1263,8 +1268,7 @@ def prepare_inputs_for_generation(

# Keep only the unprocessed tokens:
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
# input)
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input)
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
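To make the rewrapped comment above concrete: when the attention mask covers more positions than `input_ids` (some inputs live only in the cache, e.g. after passing `inputs_embeds`), only the unprocessed tail of `input_ids` is kept. A small illustration with made-up lengths:

import torch

# Made-up lengths: 5 attended positions, 3 of them already in the cache.
attention_mask = torch.ones(1, 5)
input_ids = torch.tensor([[11, 12, 13, 14]])
past_length = 3

if attention_mask.shape[1] > input_ids.shape[1]:
    input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
print(input_ids)  # tensor([[13, 14]]) -> only the unprocessed tokens remain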
@@ -1362,7 +1366,7 @@ def forward(
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache | List[torch.FloatTensor]] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
@@ -1479,7 +1483,7 @@ def forward(
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache | List[torch.FloatTensor]] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
start_positions: Optional[torch.LongTensor] = None,
end_positions: Optional[torch.LongTensor] = None,
8 changes: 4 additions & 4 deletions tests/models/llama/test_modeling_llama.py
@@ -20,9 +20,8 @@
import pytest
from parameterized import parameterized

from transformers import LlamaConfig, StaticCache, is_torch_available, logging, set_seed
from transformers import LlamaConfig, is_torch_available, set_seed
from transformers.testing_utils import (
CaptureLogger,
require_bitsandbytes,
require_flash_attn,
require_read_token,
@@ -699,7 +698,7 @@ def test_compile_static_cache(self):

prompts = [
"Simply put, the theory of relativity states that ",
"My favorite all time favorite condiment is ketchup.",
"My favorite all time favorite condiment is ketchup.",
]
tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", pad_token="</s>", padding_side="right")
model = LlamaForCausalLM.from_pretrained(
@@ -744,7 +743,7 @@ def test_compile_repeated_calls(self):

prompts = [
"Simply put, the theory of relativity states that ",
"My favorite all time favorite condiment is ketchup.",
"My favorite all time favorite condiment is ketchup.",
]
tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", pad_token="</s>", padding_side="right")
model = LlamaForCausalLM.from_pretrained(
@@ -772,6 +771,7 @@ def test_compile_repeated_calls(self):
static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION, static_compiled_text)


@require_torch
class CodeLlamaIntegrationTest(unittest.TestCase):
PROMPTS = [
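For reference, the integration tests above exercise roughly the following end-to-end flow; the checkpoint, dtype, and generation arguments here are illustrative rather than the exact test setup:

import torch

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", padding_side="right")
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf", torch_dtype=torch.float16, device_map="auto"
)

# Ask generate() for the static cache and compile the forward pass so the
# pre-allocated, statically-addressed buffers are reused across calls.
model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

inputs = tokenizer("Simply put, the theory of relativity states that ", return_tensors="pt").to(model.device)
generated_ids = model.generate(**inputs, max_new_tokens=20, do_sample=False)
print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True))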