Cache: Static cache as a standalone object #30476

Merged: 12 commits, Apr 30, 2024
Changes from 1 commit
working on generate too; make fixup
gante committed Apr 29, 2024
commit 97aa785bac68c1bc01167b24420da8c9dec6c19d
7 changes: 7 additions & 0 deletions src/transformers/cache_utils.py
@@ -422,3 +422,10 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
def get_max_length(self) -> Optional[int]:
"""Returns the maximum sequence length of the cached states."""
return self.max_cache_len

def reset(self):
"""Resets the cache values while preserving the objects"""
for layer_idx in range(len(self.key_cache)):
# In-place ops prevent breaking the static address
self.key_cache[layer_idx] *= 0.0
self.value_cache[layer_idx] *= 0.0

Collaborator:
Suggested change:
- self.key_cache[layer_idx] *= 0.0
- self.value_cache[layer_idx] *= 0.0
+ self.key_cache[layer_idx] = 0.0
+ self.value_cache[layer_idx] = 0.0

might be faster?

Member Author:
setting to a new tensor produces a graph break 💔 (I'm assuming you meant self.key_cache[layer_idx] = torch.zeros(...))

Collaborator:
No no, I think just filling them with zeros should work

Member Author:
That would result in TypeError: 'float' object is not subscriptable when indexing the cache :D

But filling with zeros with tensor.zero_() works 👍

Collaborator:
ok 👍🏻 let's go with that then!
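
For reference, a minimal standalone sketch of the point settled above (my own illustration, not code from this PR): an in-place fill such as `zero_()` keeps the cached tensor's storage address stable, which the static cache relies on for `torch.compile`, whereas rebinding the list entry to a fresh tensor changes the address.

```python
import torch

# Hypothetical single-layer key cache, shaped (batch, heads, max_cache_len, head_dim).
key_cache = [torch.ones(1, 8, 16, 64)]

address_before = key_cache[0].data_ptr()

# In-place fill: contents become zero, storage address is unchanged.
key_cache[0].zero_()
assert key_cache[0].data_ptr() == address_before

# Rebinding to a new tensor allocates new storage, so the static address is lost.
key_cache[0] = torch.zeros(1, 8, 16, 64)
assert key_cache[0].data_ptr() != address_before
```

Per the thread above, the in-place `zero_()` variant is what the reset should use.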

41 changes: 29 additions & 12 deletions src/transformers/generation/utils.py
@@ -1310,6 +1310,34 @@ def _get_initial_cache_position(self, input_ids, model_kwargs):
model_kwargs["cache_position"] = torch.arange(past_length, cur_len, device=input_ids.device)
return model_kwargs

def _get_static_cache(self, max_batch_size: int, max_cache_len: int) -> StaticCache:
"""
Sets a static cache for `generate` that will persist across calls. A new cache will only be initialized if a
new `generate` call requires a larger cache.

Returns the resulting static cache object.
"""
needs_new_cache = (
not hasattr(self, "_static_cache")
or self._static_cache.max_batch_size < max_batch_size
or self._static_cache.max_cache_len < max_cache_len
)
if needs_new_cache:
if hasattr(self.config, "_pre_quantization_dtype"):
cache_dtype = self.config._pre_quantization_dtype
else:
cache_dtype = self.dtype
self._static_cache = StaticCache(
config=self.config,
max_batch_size=max_batch_size,
max_cache_len=max_cache_len,
device=self.device,
dtype=cache_dtype,
)
else:
self._static_cache.reset() # reset the cache for a new generation
return self._static_cache

@torch.no_grad()
def generate(
self,
@@ -1526,18 +1554,7 @@ def generate(
"issue: https://github.com/huggingface/transformers/issues/28981."
)
if generation_config.cache_implementation == "static":
cache_cls = NEED_SETUP_CACHE_CLASSES_MAPPING["static"]
if hasattr(self.config, "_pre_quantization_dtype"):
cache_dtype = self.config._pre_quantization_dtype
else:
cache_dtype = self.dtype
model_kwargs["past_key_values"] = cache_cls(
config=self.config,
max_batch_size=batch_size,
max_cache_len=generation_config.max_length,
device=self.device,
dtype=cache_dtype,
)
model_kwargs["past_key_values"] = self._get_static_cache(batch_size, generation_config.max_length)

self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)

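A rough usage sketch of the behaviour this hunk enables (the checkpoint name and lengths are illustrative, not taken from the PR): with `cache_implementation="static"`, the model keeps a single `StaticCache` in `self._static_cache` and re-uses it, reset to zeros, across `generate` calls, re-allocating only when a later call needs a larger batch size or maximum cache length.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative checkpoint; any model with static-cache support behaves the same way.
model_id = "meta-llama/Llama-2-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("Hello", return_tensors="pt")

# First call: allocates the static cache for (max_batch_size=1, max_cache_len=32).
model.generate(**inputs, cache_implementation="static", max_length=32)

# Same limits: the existing cache object is reset in place and re-used.
model.generate(**inputs, cache_implementation="static", max_length=32)

# A larger max_length no longer fits, so a new StaticCache is allocated.
model.generate(**inputs, cache_implementation="static", max_length=64)
```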
55 changes: 15 additions & 40 deletions src/transformers/models/cohere/modeling_cohere.py
@@ -732,27 +732,6 @@ def _init_weights(self, module):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()

def _setup_cache(self, cache_cls, max_batch_size, max_cache_len: Optional[int] = None):
if self.config._attn_implementation == "flash_attention_2" and cache_cls == StaticCache:
raise ValueError(
"`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
"make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
)

for layer in self.model.layers:
device = layer.input_layernorm.weight.device
if hasattr(self.config, "_pre_quantization_dtype"):
dtype = self.config._pre_quantization_dtype
else:
dtype = layer.self_attn.o_proj.weight.dtype
layer.self_attn.past_key_value = cache_cls(
self.config, max_batch_size, max_cache_len, device=device, dtype=dtype
)

def _reset_cache(self):
for layer in self.model.layers:
layer.self_attn.past_key_value = None


COHERE_INPUTS_DOCSTRING = r"""
Args:
@@ -980,7 +959,7 @@ def _update_causal_mask(
attention_mask: torch.Tensor,
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_seen_tokens: int,
past_key_values: Cache,
):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
@@ -992,9 +971,12 @@
return attention_mask
return None

if self.config._attn_implementation == "sdpa":
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument,
# in order to dispatch on Flash Attention 2.
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
using_static_cache = isinstance(past_key_values, StaticCache)
if self.config._attn_implementation == "sdpa" and not using_static_cache:
Comment on lines +976 to +981
Collaborator:
this is new, and since we pass cache_position, let's use cache_position[0]

Member Author:
Agreed in theory, can't do in practice: breaks torch.fx tests 💔

Collaborator:
yeah thought so
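
As an aside, a minimal sketch (my own illustration, not part of this diff) of why the hunk above skips the SDPA `is_causal` fast path when a static cache is used: the preallocated key/value length no longer matches the query length, and `is_causal=True` alone cannot tell SDPA which cache slots are still empty padding, so an explicit mask has to be built.

```python
import torch
import torch.nn.functional as F

batch, heads, head_dim = 1, 4, 8
query_len, max_cache_len, valid_len = 3, 16, 3  # 3 real tokens in a 16-slot static cache

q = torch.randn(batch, heads, query_len, head_dim)
k = torch.randn(batch, heads, max_cache_len, head_dim)  # preallocated, mostly padding
v = torch.randn(batch, heads, max_cache_len, head_dim)

# Boolean mask, True = attend: only the first `valid_len` slots hold real tokens,
# and within them the usual causal (lower-triangular) structure applies.
mask = torch.zeros(batch, 1, query_len, max_cache_len, dtype=torch.bool)
mask[..., :valid_len] = torch.ones(query_len, valid_len, dtype=torch.bool).tril()

out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
print(out.shape)  # torch.Size([1, 4, 3, 8])
```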

if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens
):
@@ -1003,9 +985,9 @@
dtype, device = input_tensor.dtype, input_tensor.device
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
if hasattr(getattr(self.layers[0], "self_attn", {}), "past_key_value"): # static cache
target_length = self.config.max_position_embeddings
else: # dynamic cache
if using_static_cache:
target_length = past_key_values.get_max_length()
Comment on lines +990 to +991
Collaborator:
can't we always use get_max_length()?

Member Author:
get_max_length() is None in the dynamic caches

Collaborator:
It should be seq_length

Collaborator:
but alright
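
To make the exchange above concrete, a small sketch (using a tiny made-up config, not code from the PR) of the behaviour the branch below depends on: `DynamicCache` has no fixed capacity, so `get_max_length()` returns `None`, while `StaticCache` reports its preallocated `max_cache_len`; that is why `get_max_length()` is only consulted in the static-cache branch.

```python
import torch
from transformers import LlamaConfig
from transformers.cache_utils import DynamicCache, StaticCache

# Tiny illustrative config so the preallocated cache tensors stay small.
config = LlamaConfig(
    num_hidden_layers=2,
    hidden_size=64,
    num_attention_heads=4,
    num_key_value_heads=2,
    intermediate_size=128,
)

dynamic_cache = DynamicCache()
static_cache = StaticCache(
    config=config, max_batch_size=1, max_cache_len=256, device="cpu", dtype=torch.float32
)

print(dynamic_cache.get_max_length())  # None -> fall back to the attention mask length
print(static_cache.get_max_length())   # 256  -> use the preallocated capacity
```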

else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
@@ -1027,6 +1009,10 @@
# backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
# cache. In that case, the 4D attention mask attends to the newest tokens only.
if attention_mask.shape[-2] < cache_position[0] + sequence_length:
logger.warning_once(
"Passing a 4d mask shorter than the input length is deprecated and will be removed in "
"transformers v4.42.0"
)
offset = cache_position[0]
else:
offset = 0
@@ -1184,13 +1170,6 @@ def prepare_inputs_for_generation(
use_cache=True,
**kwargs,
):
# With static cache, the `past_key_values` is None
# TODO joao: standardize interface for the different Cache classes and remove of this if
has_static_cache = False
if past_key_values is None:
past_key_values = getattr(getattr(self.model.layers[0], "self_attn", {}), "past_key_value", None)
has_static_cache = past_key_values is not None

past_length = 0
if past_key_values is not None:
if isinstance(past_key_values, Cache):
Expand All @@ -1208,8 +1187,7 @@ def prepare_inputs_for_generation(

# Keep only the unprocessed tokens:
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
# input)
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input)
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
@@ -1249,9 +1227,6 @@
elif use_cache:
cache_position = cache_position[-input_length:]

if has_static_cache:
past_key_values = None

model_inputs.update(
{
"position_ids": position_ids,
36 changes: 16 additions & 20 deletions src/transformers/models/gemma/modeling_gemma.py
@@ -966,7 +966,7 @@ def _update_causal_mask(
attention_mask: torch.Tensor,
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_seen_tokens: int,
past_key_values: Cache,
):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
Expand All @@ -978,9 +978,12 @@ def _update_causal_mask(
return attention_mask
return None

if self.config._attn_implementation == "sdpa":
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument,
# in order to dispatch on Flash Attention 2.
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
using_static_cache = isinstance(past_key_values, StaticCache)
if self.config._attn_implementation == "sdpa" and not using_static_cache:
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens
):
Expand All @@ -989,9 +992,9 @@ def _update_causal_mask(
dtype, device = input_tensor.dtype, input_tensor.device
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
if hasattr(getattr(self.layers[0], "self_attn", {}), "past_key_value"): # static cache
target_length = self.config.max_position_embeddings
else: # dynamic cache
if using_static_cache:
target_length = past_key_values.get_max_length()
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
@@ -1013,6 +1016,10 @@
# backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
# cache. In that case, the 4D attention mask attends to the newest tokens only.
if attention_mask.shape[-2] < cache_position[0] + sequence_length:
logger.warning_once(
"Passing a 4d mask shorter than the input length is deprecated and will be removed in "
"transformers v4.42.0"
)
offset = cache_position[0]
else:
offset = 0
@@ -1166,13 +1173,6 @@ def prepare_inputs_for_generation(
use_cache=True,
**kwargs,
):
# With static cache, the `past_key_values` is None
# TODO joao: standardize interface for the different Cache classes and remove of this if
has_static_cache = False
if past_key_values is None:
past_key_values = getattr(getattr(self.model.layers[0], "self_attn", {}), "past_key_value", None)
has_static_cache = past_key_values is not None

past_length = 0
if past_key_values is not None:
if isinstance(past_key_values, Cache):
Expand All @@ -1190,8 +1190,7 @@ def prepare_inputs_for_generation(

# Keep only the unprocessed tokens:
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
# input)
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input)
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
@@ -1231,9 +1230,6 @@
elif use_cache:
cache_position = cache_position[-input_length:]

if has_static_cache:
past_key_values = None

model_inputs.update(
{
"position_ids": position_ids,
@@ -1293,7 +1289,7 @@ def forward(
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
4 changes: 2 additions & 2 deletions src/transformers/models/jamba/modeling_jamba.py
@@ -29,7 +29,7 @@
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...cache_utils import DynamicCache # we need __iter__ and __len__ of pkv
from ...cache_utils import Cache, DynamicCache # we need __iter__ and __len__ of pkv
from ...modeling_attn_mask_utils import (
AttentionMaskConverter,
)
@@ -1807,7 +1807,7 @@ def forward(
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
2 changes: 1 addition & 1 deletion src/transformers/models/mistral/modeling_mistral.py
@@ -1301,7 +1301,7 @@ def forward(
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
2 changes: 1 addition & 1 deletion src/transformers/models/mixtral/modeling_mixtral.py
@@ -1525,7 +1525,7 @@ def forward(
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,