sliding_window shouldn't be applied when flash_attn not installed? #680

Open
rossbm opened this issue Jun 21, 2024 · 2 comments
Labels: currently fixing (Am fixing now!)

Comments

rossbm commented Jun 21, 2024

I've been finetuning unsloth/Phi-3-mini-4k-instruct-bnb-4bit with a T4, which doesn't support flash attention, so I don't have it installed.

During evaluation, I've been running into the following error:

File /anaconda/envs/text2text-tagger/lib/python3.11/site-packages/unsloth/models/llama.py:218, in LlamaAttention_fast_forward_inference(self, hidden_states, past_key_value, position_ids, do_prefill, attention_mask)
    216     A = torch.matmul(A, Vnn, out = Qn)
    217 else:
--> 218     A = scaled_dot_product_attention(Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = False)
    219 pass
    220 A = A.transpose(1, 2)

RuntimeError: The expanded size of the tensor (2047) must match the existing size (2956) at non-singleton dimension 3.  Target sizes: [2, 32, 1, 2047].  Tensor sizes: [2, 1, 1, 2956]

The batch that is being evaluated at this point has 2955 tokens. However, unsloth/Phi-3-mini-4k-instruct-bnb-4bit should support sequence lengths of 4096 tokens, and I make certain to set max_seq_length to 4096 when initializing the model.

Looking through the model config for unsloth/Phi-3-mini-4k-instruct-bnb-4bit, I see "sliding_window": 2048, which is the only place a length of 2048 (or 2047) could be coming from.

In https://github.com/unslothai/unsloth/blob/main/unsloth/models/llama.py, we have:

if sliding_window is not None and kv_seq_len > sliding_window:

However, in https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/modeling_phi3.py, there's a check for whether flash_attn is installed and supports a sliding window:

# Transformers scans dependencies in the modeling file, causing issues on conditional loading. The regex only ignores try/catch blocks, but not if statements
# if is_flash_attn_2_available():
_flash_supports_window_size = False
try:
    from flash_attn import flash_attn_func, flash_attn_varlen_func
    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa

    _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
except ImportError as error:
    logger.warning(
        f"`flash-attention` package not found, consider installing for better performance: {error}."
    )
    if not _flash_supports_window_size:
        logger.warning(
            "Current `flash-attention` does not support `window_size`. Either upgrade or use `attn_implementation='eager'`."
        )

before the sliding window is used:

use_sliding_windows = (
    _flash_supports_window_size
    and getattr(self.config, "sliding_window", None) is not None
    and kv_seq_len > self.config.sliding_window
)

Sure enough, when I set model.config.sliding_window = 10_000 I am able to successfully call model.generate() on the batch that was giving me the RuntimeError: The expanded size of the tensor (2047) ... error.
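For concreteness, a rough sketch of that workaround (assuming the usual unsloth FastLanguageModel loading pattern; the prompt and generation arguments are just placeholders):

from unsloth import FastLanguageModel

# Load the 4-bit Phi-3 model the same way as for finetuning.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Phi-3-mini-4k-instruct-bnb-4bit",
    max_seq_length = 4096,
    load_in_4bit = True,
)
FastLanguageModel.for_inference(model)

# Workaround, not a fix: make the sliding window larger than any sequence
# I pass in, so the slicing branch in LlamaAttention_fast_forward_inference
# is never taken.
model.config.sliding_window = 10_000

inputs = tokenizer(["<a prompt longer than 2048 tokens>"], return_tensors = "pt").to("cuda")
outputs = model.generate(**inputs, max_new_tokens = 64)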

So I think the solution is to update if sliding_window is not None and kv_seq_len > sliding_window: to also check whether flash-attention is installed and supports window_size, similar to what Phi-3 is doing.
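A hedged sketch of what that could look like (the helper names below are hypothetical, mirroring the _flash_supports_window_size check from modeling_phi3.py quoted above; I haven't worked out where exactly it would slot into unsloth's llama.py):

import inspect

# Hypothetical helper, mirroring modeling_phi3.py: True only when flash_attn
# is importable and its flash_attn_func accepts a `window_size` argument.
def _flash_supports_window_size() -> bool:
    try:
        from flash_attn import flash_attn_func
    except ImportError:
        return False
    return "window_size" in inspect.signature(flash_attn_func).parameters

# Hypothetical replacement for the bare size check in unsloth/models/llama.py.
def should_apply_sliding_window(sliding_window, kv_seq_len) -> bool:
    return (
        _flash_supports_window_size()
        and sliding_window is not None
        and kv_seq_len > sliding_window
    )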

rossbm (Author) commented Jun 24, 2024

I've tried running on another VM where I've installed flash_attn, but I'm still getting the error. Maybe the issue is that the slicing_tokens adjustment isn't being applied to the attention mask.

From https://github.com/unslothai/unsloth/blob/main/unsloth/models/llama.py

if sliding_window is not None and kv_seq_len > sliding_window:
    # From https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py#L193
    slicing_tokens = 1 - sliding_window
    Knn = Kn[:, :, slicing_tokens:, :]#.contiguous()
    Vnn = Vn[:, :, slicing_tokens:, :]#.contiguous()

While in https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/modeling_phi3.py and https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py we have:

if attention_mask is not None:
    attention_mask = attention_mask[:, slicing_tokens:]
    attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
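If that's the cause, a minimal sketch of the missing step (a hypothetical helper; tensor names follow the unsloth snippet above, and I haven't checked exactly where in LlamaAttention_fast_forward_inference it would go):

def slice_kv_and_mask(Kn, Vn, attention_mask, sliding_window):
    # What unsloth already does: keep the last (sliding_window - 1) positions
    # of the cached keys and values.
    slicing_tokens = 1 - sliding_window
    Knn = Kn[:, :, slicing_tokens:, :]
    Vnn = Vn[:, :, slicing_tokens:, :]

    # The missing step: trim the mask's key dimension the same way, so
    # scaled_dot_product_attention sees matching lengths. The mask in the
    # traceback is 4-D ([2, 1, 1, 2956]), hence slicing the last dim; Phi-3 and
    # Mistral slice a 2-D padding mask and then append a ones column for the
    # incoming token, which may or may not be needed here.
    if attention_mask is not None:
        attention_mask = attention_mask[..., slicing_tokens:]

    return Knn, Vnn, attention_mask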

danielhanchen (Contributor) commented

@rossbm Much apologies for the delay - my bro and I just relocated to SF. Appreciate the investigation as well!

I shall check if I'm doing inference on SWAs correctly :) Thanks for the report!

danielhanchen added the "currently fixing (Am fixing now!)" label on Jul 1, 2024