
Static Cache: no mandatory cache_positions input #29221

Closed
gante wants to merge 12 commits

Conversation

@gante (Member) commented Feb 22, 2024

What does this PR do?

Removes the hard requirement of cache_position from the public model classes (e.g. classes that can be loaded with AutoModel). Its contents can be derived from the cache's get_seq_length().
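
A minimal sketch of that derivation (the helper name and arguments are illustrative, not code from the PR): given how many tokens the cache already holds, the positions written by the current forward pass are simply the next contiguous range.

import torch

def derive_cache_position(past_seen_tokens: int, num_new_tokens: int, device="cpu") -> torch.Tensor:
    # Positions the current forward pass will occupy: the range starting right after
    # the tokens already stored in the cache.
    return torch.arange(past_seen_tokens, past_seen_tokens + num_new_tokens, device=device)

# e.g. a cache already holding 5 tokens and a forward pass over 3 new tokens
print(derive_cache_position(5, 3))  # tensor([5, 6, 7])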


Performance and Quality checks

Slow tests were run, getting the same outcome as on main:

  • RUN_SLOW=1 py.test tests/models/llama/test_modeling_llama.py -vv
  • RUN_SLOW=1 py.test tests/models/gemma/test_modeling_gemma.py -vv
  • RUN_SLOW=1 py.test tests/test_cache_utils.py -vv [Note: two new tests were added here]

👉 Note that these tests ensure that torch.compile works, with and without attention_mask being passed.

Local benchmark (RTX3090, tiny llama) -- no changes
(main)
[Screenshot: benchmark results on main, 2024-02-22]

(this pr)
[Screenshot: benchmark results for this PR, 2024-02-22]

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@gante marked this pull request as ready for review February 22, 2024 19:46
@gante (Member, Author) commented Feb 22, 2024

cc @fxmarty :)

@ArthurZucker (Collaborator) left a comment

I don't think this is the way we should go. So far we have not seen issues with adding a new argument; we added padding_mask at some point.
IMO we should move to having cache_positions in generate, as this is more explicit, easier to maintain, and less error-prone.
More than that, these operations are input-dependent, and even when vectorized they add complexity where there should not be any!

src/transformers/models/gemma/modeling_gemma.py (outdated, resolved)
@fxmarty (Contributor) commented Feb 23, 2024

No specific opinion, just that using cache_position in the modeling code when use_cache=False is not very intuitive to me:

if cache_position is None:
    if isinstance(past_key_values, StaticCache):
        raise ValueError("cache_position is a required argument when using StaticCache.")
    cache_position = torch.arange(
        past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
    )
if position_ids is None:
    position_ids = cache_position.unsqueeze(0)

The only thing I think is that cache_position (if it remains) should be a strictly optional argument, and the model should work without it for BC (backward compatibility).

@gante (Member, Author) commented Feb 23, 2024

EDIT -- before you read why I believe removing cache_positions from the signature is an important change for our users, have a look at the current diff :)

✅ much simpler computations (@ArthurZucker this addresses your concerns, although the previous version also did not introduce slowdowns)
✅ a working StaticCache.get_seq_length(), which removes a few exceptions
✅ no exception is raised when an input tensor is missing with the StaticCache (cache_positions on main, position_ids in the previous commit)
[There are still a few if/elses for cache-type-dependent logic in the model, but #29180 will standardize the cache API so that all caches match the static cache]

Updated performance numbers (i.e. still no degradation in throughput)

[Screenshot: updated benchmark results, 2024-02-23]


No specific opinion, just that using cache_position in the modeling code when use_cache=False is not very intuitive to me:

@fxmarty I agree that it should be renamed :) e.g. data_position would be more appropriate, as it may be used without a cache. Precise variable names help us in the long run.

IMO we should move to having cache_positions in generate, as this is more explicit, easier to maintain, and less error-prone.
More than that, these operations are input-dependent, and even when vectorized they add complexity where there should not be any!

@ArthurZucker I strongly disagree, more redundant inputs result in more bugs 🐛 From the full input_ids and pad_token_id we can derive attention_mask, position_ids, and cache_positions -- it's not like we are giving more options to the users, there is only one set of valid tensors given the first two. In other words, we give users three different paths to get things wrong! As a result, we have to add safety warnings and exceptions, such as:

  1. warning if there is a pad token in the input but no attention mask
  2. raising an exception if cache positions are not passed
  3. raising an exception if the attention mask is not of the right type
  4. warning if both attention mask and pad token are missing in generate
    (plus other error-prone combinations that we are not even checking 😬)

This is arguably much more work than enforcing the correct pairing in the models themselves! On top of that, other functions like generate or pipeline have to create the correct logic anyway (e.g. create the attention mask; create the position ids), so why not push the logic to the models and prevent these bugs from occurring everywhere? 🤗
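
As a self-contained illustration of the derivation claim above (the helper and the left-padded example batch are illustrative, not code from the PR): for a prompt-only forward pass, all three tensors follow deterministically from input_ids and pad_token_id.

import torch

def derive_forward_inputs(input_ids: torch.Tensor, pad_token_id: int):
    # Prompt-only (no past cache), left-padded batch.
    attention_mask = (input_ids != pad_token_id).long()
    # Positions count real tokens only; padded slots are clamped to 0 (one possible convention).
    position_ids = (attention_mask.cumsum(-1) - 1).clamp(min=0)
    # The cache slots written by this forward pass are simply 0..seq_len-1.
    cache_position = torch.arange(input_ids.shape[1])
    return attention_mask, position_ids, cache_position

batch = torch.tensor([[0, 0, 11, 12, 13],     # pad_token_id = 0, left padding
                      [21, 22, 23, 24, 25]])
attention_mask, position_ids, cache_position = derive_forward_inputs(batch, pad_token_id=0)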

My goal with this PR is to iron out future sources of bugs before we roll out these changes to other models. Better than well-explained interfaces is... no need for that interface at all! Without these changes, users have to learn how to prepare a new input tensor to use static caches.

In terms of complexity, it will always exist. The difference is whether it exists outside generate (=users can get it wrong, one more argument in generate), inside generate (=we have to maintain the logic, users can get custom generation loops wrong, one more argument in forward), or inside forward (=we have to maintain the logic). In my experience, if some input can be used incorrectly, it WILL be used incorrectly. Having few bugs is one of the strong points of transformers, which we should strive to keep. [In the same direction, I think we should get rid of attention_mask and position_ids, but that would be a breaking change 😉 There are so many issues in this repo due to these variables 😭 I remember that at some point you were also trying to get rid of position_ids]

In terms of performance, we can see that it is negligible in eager forwards and none in compiled static forwards. Again, if not implemented here, it will be implemented in generate, resulting in the same net performance.

Comment on lines 401 to 403
# Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
# limit the check to the first batch member and head dimension.
return (self.key_cache[0, 0].any(dim=-1)).sum()
@fxmarty (Contributor) commented Feb 26, 2024

I'm not a fan of this. This would be run at each forward. Couldn't we just increment somewhere, or not use get_seq_length at all in the forward?

Also, given the issues we've had: did you check that it works with compile (meaning: self.model.layers[0].past_key_values.get_seq_length() is correct in between each forward call in generate when the model is compiled with model = torch.compile(model, mode="reduce-overhead"))? I removed the requirement on get_seq_length + seen_tokens because we were getting wrong values.

@gante (Member, Author) commented Feb 26, 2024

@fxmarty yes, this version works! e.g. try running RUN_SLOW=1 py.test tests/test_cache_utils.py::CacheIntegrationTest::test_static_cache_greedy_decoding_pad_left_0_eager on the latest commit, which confirms that the dynamic cache, the eager static cache, and the compiled static cache all produce the same sensible outputs.

I'm not a fan of this. This would be run at each forward. Couldn't we just increment somewhere?

I am very much aligned with your comment! Indeed, that was my original plan when adding seen_tokens to the original cache implementation. However, as you mentioned, we haven't found a way to make it work at compile time. Perhaps we should add a TODO to revisit this in the future, after the issue you opened on the PyTorch side has been addressed?

Collaborator

  • if part of the previous cache happens to be a zero tensor, this check breaks (very rare in full precision, maybe not in half precision; we had a similar issue with llava when relying on tensors being zero) (see the small illustration below)
  • the cost is OK I guess
  • it's totally implicit 😢 and that's less aligned with our overall philosophy.
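
A tiny self-contained illustration of the zero-tensor concern from the first bullet (the (batch, heads, seq_len, head_dim) layout is assumed to match the static cache):

import torch

# Three cache slots have been written, but one of them is exactly zero, so the
# "any non-zero value along head_dim" occupancy check undercounts.
key_cache = torch.zeros(1, 1, 8, 4)      # (batch, heads, seq_len, head_dim)
key_cache[0, 0, :3] = torch.randn(3, 4)  # three slots written
key_cache[0, 0, 1] = 0.0                 # ...one of them happens to be all zeros
print((key_cache[0, 0].any(dim=-1)).sum())  # tensor(2), not 3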

@ArthurZucker (Collaborator) left a comment

Thanks for your great explanation.
TL;DR: I think part of this choice comes down to being implicit vs explicit.
Your approach works, but it's implicit. Both have ➕ and ➖ for maintainability.
I am leaning a lot more towards the explicit for a few reasons:

  1. If you generate the cache_positions in the model with torch.arange, the generations you will get with cuda graphs will be wrong precisely because the cache positions are not given as input. I tried this and it does not work.
  2. ➕ Instead of relying on a hidden seen_tokens and 3 different get_length, get_usable_length, get_max_length, you just need cache_positions in generate.
  3. ➕ Stateless is less error-prone; cache_positions is also well known from GPT-Fast.
  4. "From the full input_ids and pad_token_id we can derive attention_mask, position_ids, and cache_positions"
    is not true: if you want to do paged attention / packed training with a custom 4D attention mask, or anything a bit outside the classic forward (any custom Cache class), then you cannot rely only on input_ids and the pad token.

  5. The generate function should keep track of where it is, for simplicity but also because otherwise you're not really able to manipulate anything outside the forward of the model. So at the end of the day you have a loop in generate, which is the only public place where we use the past_key_values, but you don't let it handle the cache_positions, which is a bit strange!

Let's keep the cache positions, gradually remove calls to get_seq_length, and explicitly pass the arguments to the model instead of relying on the model to infer them.
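
For concreteness, a hedged sketch of the explicit style described here, where the decoding loop (not the model) owns and advances cache_position. The model is assumed to be a causal LM whose forward accepts cache_position; the loop is illustrative, not the actual generate code.

import torch

def greedy_decode(model, input_ids: torch.Tensor, max_new_tokens: int) -> torch.Tensor:
    cache_position = torch.arange(input_ids.shape[1], device=input_ids.device)
    past_key_values = None
    tokens = input_ids
    for _ in range(max_new_tokens):
        model_inputs = tokens if past_key_values is None else tokens[:, -1:]
        out = model(
            model_inputs,
            past_key_values=past_key_values,
            cache_position=cache_position,
            use_cache=True,
        )
        past_key_values = out.past_key_values
        next_token = out.logits[:, -1:].argmax(-1)
        tokens = torch.cat([tokens, next_token], dim=-1)
        # The loop, not the model, advances the positions by exactly one slot per step.
        cache_position = cache_position[-1:] + 1
    return tokens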



# `torch.compile`-friendly `torch.arange` from a shape
cache_position = torch.ones_like(inputs_embeds[0, :, 0], dtype=torch.int64).cumsum(0) + past_seen_tokens - 1
Collaborator

If the cache positions are used internally anyway, that means we still have to create them at some point.
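
For reference, a small check (shapes and offset are illustrative) that the ones_like(...).cumsum(0) construction in the diff above yields the same indices as a plain torch.arange:

import torch

inputs_embeds = torch.randn(2, 5, 16)  # (batch, seq_len, hidden), illustrative
past_seen_tokens = 7

via_cumsum = torch.ones_like(inputs_embeds[0, :, 0], dtype=torch.int64).cumsum(0) + past_seen_tokens - 1
via_arange = torch.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1])
assert torch.equal(via_cumsum, via_arange)  # both are tensor([7, 8, 9, 10, 11])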

@gante changed the title from "Static Cache: no cache_positions required in the public classes" to "Static Cache: no mandatory cache_positions input" on Feb 28, 2024
@gante (Member, Author) commented Feb 28, 2024

@ArthurZucker as we discussed on Slack, the latest set of commits makes cache_position follow the same logic as attention_mask and position_ids: it enables both explicit [full control of forward] and implicit [lazy usage of forward] usage 🤗

NOTE: ignore the changes in prepare_inputs_for_generation for now; from what I'm seeing while enabling torch.compile, we will need more changes there!

In summary, this PR:

  1. Enables implicit cache_position, compatible with torch.compile (see the usage sketch after this list);
  2. Fixes StaticCache.get_seq_length() (needed for implicit cache_position in prepare_inputs_for_generation);
  3. Deprecates the seen_tokens attribute in the other caches, as it is redundant with cache_position;
  4. Slicing in prepare_inputs_for_generation with StaticCache now shares the original path, as opposed to having custom logic (a bunch of TODOs were added here, only feasible after future PRs);
  5. Adds docstrings for cache_position in forward;
  6. Adds several tests related to cache_position, to ensure it is working as expected.
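
A hedged usage sketch of the explicit vs. implicit paths from item 1 (the checkpoint name is an assumption, any llama-family model should behave the same way, and the call pattern mirrors this discussion rather than a verified final API):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # assumption: any llama-family checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("Hello", return_tensors="pt")

# Implicit path: the model derives cache_position internally from the cache length.
out_implicit = model(**inputs)

# Explicit path: the caller controls cache_position (full control of forward).
cache_position = torch.arange(inputs["input_ids"].shape[1])
out_explicit = model(**inputs, cache_position=cache_position)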

Comment on lines +1125 to +1129
past_length = (
    cache_position[-1] + 1 if cache_position is not None else past_key_values.get_seq_length()
)
max_cache_length = past_key_values.get_max_length()
cache_length = past_length if max_cache_length is None else min(max_cache_length, int(past_length))
Member Author

This restructure prioritizes cache_position, falling back to .get_seq_length() in its absence. This replaces seen_tokens, which is now deprecated.

Note that past_length [all seen tokens] and cache_length [tokens in the cache] are both needed, otherwise SinkCache won't work.
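
A small numeric illustration of the distinction (values are illustrative): a sink-style cache with a 4-slot window that has already processed 10 tokens.

past_length = 10                 # all tokens seen so far
max_cache_length = 4             # window size reported by get_max_length()
cache_length = past_length if max_cache_length is None else min(max_cache_length, int(past_length))
print(cache_length)              # 4: tokens actually held in the cache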

Comment on lines +1165 to +1167
input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
position_ids = position_ids.contiguous() if position_ids is not None else None
Member Author

This is already on main (see here), not sure why this shows up 👀


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@ArthurZucker (Collaborator) commented

@gante could you rebase / fix the merges and I'll review!

@ArthurZucker self-requested a review March 25, 2024 09:52
@ArthurZucker (Collaborator) left a comment

Late review, not sure what the status is.
Overall the goal is to not have to pass cache positions, but the logic is quite complicated, whereas asking users to pass the cache positions seems a lot simpler, no?

Comment on lines +64 to +73
@property
def seen_tokens(self):
    logger.warning_once(
        "The `seen_tokens` attribute is deprecated and will be removed in v4.40. Use the `cache_position` "
        "variable instead."
    )
    if hasattr(self, "_seen_tokens"):
        return self._seen_tokens
    else:
        return None
Collaborator

nice 😉

Comment on lines +866 to +873
if use_cache:
    static_cache = getattr(self.layers[0].self_attn, "past_key_value", None)
    if static_cache is not None:
        past_seen_tokens = static_cache.get_seq_length()
    else:
        if not isinstance(past_key_values, Cache):
            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
        past_seen_tokens = past_key_values.get_seq_length()
Collaborator

That's a lot of work 😓
It does not seem like this is needed. Two cases:

  1. No cache positions -> not using generate, or not using cache positions -> use the DynamicCache, so the previous code works for the past length
  2. Cache positions -> use them

Collaborator

IMO we should move towards everybody passing the cache positions, and we should not use past_seen_tokens = static_cache.get_seq_length().


if cache_position is None:
    cache_position = torch.arange(
        past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
    )
# `torch.compile`-friendly `torch.arange` from a shape
Collaborator

Does that also fix the ONNX export issue we had?

past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_seen_tokens = past_key_values.get_seq_length()
if use_cache:
static_cache = getattr(self.layers[0].self_attn, "past_key_value", None)
Collaborator

We broke AWQ a few times with this; should we check generation_config.cache_implementation instead?

@gante (Member, Author) commented Apr 18, 2024

Closing this PR and other cache PRs, as we want to move in the opposite direction (static cache behaving like the other caches)

@gante closed this Apr 18, 2024