Cache: Static cache as a standalone object #30476

Merged (12 commits) on Apr 30, 2024
Changes from 1 commit
finalized mvp
gante committed Apr 29, 2024
commit cd6aecd7c288088c097b61c199e7bdd188dff185
9 changes: 3 additions & 6 deletions docs/source/en/llm_optims.md
@@ -70,7 +70,7 @@ Under the hood, `generate` will attempt to reuse the same cache object, removing
</hfoption>
<hfoption id="Static Cache">

-A [`StaticCache`] object can be passed to the model's forward pass under the `past_key_values` argument, enabling the use of this object as a static kv-cache. Using this strategy, you'll need to write your own function to decode the next token given the current token and position and cache position of previously generated tokens.
+A [`StaticCache`] object can be passed to the model's forward pass under the `past_key_values` argument, enabling the use of this object as a static kv-cache. Using this strategy, you can write your own function to decode the next token given the current token and position and cache position of previously generated tokens. You can also pass the [`StaticCache`] object to [`~GenerationMixin.generate`] and use it across calls, like you would do with a dynamic cache.

```py
from transformers import LlamaTokenizer, LlamaForCausalLM, StaticCache, logging
@@ -142,11 +142,8 @@ text
'My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p']
```

-Please note that the cache has to be manually reset if you want to repeat this process multiple times reusing the same cache object.
-
-```py
-past_key_values.reset() # Clears the cache's contents without destroying the objects
-```
+> [!TIP]
+> If you want to reuse the [`StaticCache`] object on a new prompt, be sure to reset its contents with the `.reset()` method

</hfoption>
</hfoptions>
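As a side note, the workflow this docs change describes boils down to building a [`StaticCache`] once, passing it to `generate` through `past_key_values`, and calling `.reset()` before reusing it on a new prompt. Below is a minimal sketch of that flow; the checkpoint name, cache length, prompts, and generation settings are illustrative assumptions, not values taken from this PR.

```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, StaticCache

# Illustrative checkpoint; any decoder-only model with static-cache support should work the same way
model_id = "meta-llama/Llama-2-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")

# Build the standalone static cache once, sized for prompt + generated tokens
past_key_values = StaticCache(
    config=model.config, max_batch_size=1, max_cache_len=256, device=model.device, dtype=model.dtype
)

inputs = tokenizer("My favorite all time favorite condiment is", return_tensors="pt").to(model.device)
out = model.generate(**inputs, past_key_values=past_key_values, max_new_tokens=20)
print(tokenizer.batch_decode(out, skip_special_tokens=True))

# Reset the cache's contents (the object itself is kept) before reusing it on an unrelated prompt
past_key_values.reset()
inputs = tokenizer("The secret to a good stew is", return_tensors="pt").to(model.device)
out = model.generate(**inputs, past_key_values=past_key_values, max_new_tokens=20)
print(tokenizer.batch_decode(out, skip_special_tokens=True))
```

The pre-allocated key/value tensors are reused across both `generate` calls; only the `.reset()` call is needed between unrelated prompts.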
29 changes: 21 additions & 8 deletions tests/quantization/aqlm_integration/test_aqlm.py
@@ -198,7 +198,12 @@ def test_quantized_model_compile(self):
# Sample tokens greedily
def decode_one_tokens(model, cur_token, input_pos, cache_position, past_key_values):
logits = model(
-cur_token, position_ids=input_pos, cache_position=cache_position, past_key_values=past_key_values, return_dict=False, use_cache=True
+cur_token,
+position_ids=input_pos,
+cache_position=cache_position,
+past_key_values=past_key_values,
+return_dict=False,
+use_cache=True,
)[0]
new_token = torch.argmax(logits[:, [-1]], dim=-1).to(torch.int)

@@ -209,12 +214,12 @@ def decode_one_tokens(model, cur_token, input_pos, cache_position, past_key_values):
seq_length = input_ids.shape[1]

# Setup static KV cache for generation
-if hasattr(self.quantized_model.config, "_pre_quantization_dtype"):
-cache_dtype = self.quantized_model.config._pre_quantization_dtype
-else:
-cache_dtype = self.quantized_model.dtype
past_key_values = StaticCache(
-config=self.quantized_model.config, max_batch_size=2, max_cache_len=seq_length + self.max_new_tokens + 1, device=torch_device, dtype=cache_dtype
+config=self.quantized_model.config,
+max_batch_size=1,
+max_cache_len=seq_length + self.max_new_tokens + 1,
+device=torch_device,
+dtype=self.quantized_model.config._pre_quantization_dtype,
)

# Allocate token ids to be generated and copy prefix ids
@@ -223,7 +228,13 @@ def decode_one_tokens(model, cur_token, input_pos, cache_position, past_key_values):
generated_ids[:, cache_position] = input_ids.to(torch_device).to(torch.int)

# Do a forward pass to fill the prefix cache and compile the kernels if necessary
-logits = self.quantized_model(input_ids, cache_position=cache_position, past_key_values=past_key_values, return_dict=False, use_cache=True)[0]
+logits = self.quantized_model(
+input_ids,
+cache_position=cache_position,
+past_key_values=past_key_values,
+return_dict=False,
+use_cache=True,
+)[0]
next_token = torch.argmax(logits[:, [-1]], dim=-1).to(torch.int)
generated_ids[:, [seq_length]] = next_token

@@ -235,7 +246,9 @@ def decode_one_tokens(model, cur_token, input_pos, cache_position, past_key_values):
cache_position = torch.tensor([seq_length + 1], device=torch_device)
for _ in range(1, self.max_new_tokens):
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
-next_token = decode_one_tokens(self.quantized_model, next_token.clone(), None, cache_position, past_key_values)
+next_token = decode_one_tokens(
+self.quantized_model, next_token.clone(), None, cache_position, past_key_values
+)
generated_ids.index_copy_(1, cache_position, next_token)
cache_position += 1

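For completeness, the per-token decoding pattern exercised by the test above can be assembled into a standalone sketch. The checkpoint, prompt, and token budget below are placeholders (the test itself runs an AQLM-quantized model on `torch_device` under an sdpa kernel context), so treat them as assumptions rather than part of the test.

```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, StaticCache

# Assumed stand-ins for the test fixture
model_id = "meta-llama/Llama-2-7b-hf"
device = "cuda"
max_new_tokens = 32

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).to(device)

def decode_one_tokens(model, cur_token, input_pos, cache_position, past_key_values):
    # One decoding step: feed a single token plus its cache position, greedily pick the next token
    logits = model(
        cur_token,
        position_ids=input_pos,
        cache_position=cache_position,
        past_key_values=past_key_values,
        return_dict=False,
        use_cache=True,
    )[0]
    return torch.argmax(logits[:, [-1]], dim=-1).to(torch.int)

inputs = tokenizer("My favorite all time favorite condiment is", return_tensors="pt").to(device)
input_ids = inputs.input_ids
seq_length = input_ids.shape[1]

# Set up the static KV cache, sized for the prompt plus everything we plan to generate
past_key_values = StaticCache(
    config=model.config,
    max_batch_size=1,
    max_cache_len=seq_length + max_new_tokens + 1,
    device=device,
    dtype=model.dtype,
)

# Pre-allocate the output buffer and copy in the prompt ids
generated_ids = torch.zeros(1, seq_length + max_new_tokens + 1, dtype=torch.int, device=device)
cache_position = torch.arange(seq_length, device=device)
generated_ids[:, cache_position] = input_ids.to(torch.int)

# Prefill: one forward pass over the whole prompt fills the cache
logits = model(
    input_ids, cache_position=cache_position, past_key_values=past_key_values, return_dict=False, use_cache=True
)[0]
next_token = torch.argmax(logits[:, [-1]], dim=-1).to(torch.int)
generated_ids[:, [seq_length]] = next_token

# Decode: one token at a time, advancing the explicit cache position each step
cache_position = torch.tensor([seq_length + 1], device=device)
for _ in range(1, max_new_tokens):
    next_token = decode_one_tokens(model, next_token.clone(), None, cache_position, past_key_values)
    generated_ids.index_copy_(1, cache_position, next_token)
    cache_position += 1

print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True))
```

The explicit `cache_position` tensor tells the static cache which slot each new key/value pair should be written to, which is what keeps tensor shapes fixed across steps.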