
AttributeError: 'LlamaForCausalLM' object has no attribute '_setup_cache' #31157

Closed
mobicham opened this issue May 31, 2024 · 3 comments

Comments

@mobicham
Contributor

System Info

  • transformers version: 4.41.1
  • Platform: Linux-5.15.0-105-generic-x86_64-with-glibc2.35
  • Python version: 3.10.13
  • Huggingface_hub version: 0.23.2
  • Accelerate version: 0.30.1
  • PyTorch version (GPU?): 2.4.0.dev20240527+cu121 (True)

Who can help?

@ArthurZucker @gante

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

It seems that _setup_cache() was removed in the newer versions of transformers. Is there an alternative to it?
The call is necessary in order to properly compile the model for faster generation.
Thanks!

Expected behavior

model._setup_cache() should set up the cache.
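
For reference, this is roughly how the removed call was used in releases before 4.41 (a minimal sketch; the exact argument names are from memory and may differ):

from transformers import StaticCache

# Pre-4.41 API (since removed): attach a static KV cache to the model in place
# before compiling the forward pass. Argument names may not be exact.
model._setup_cache(StaticCache, max_batch_size=1, max_cache_len=1024)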

@zucchini-nlp
Member

StaticCache was moved to be a standalone object in #30476. You have to initialize StaticCache outside the model and pass it in every forward call, similar to the following:

# StaticCache(config, max_batch_size, max_cache_len, device, dtype)
past_key_values = StaticCache(model.config, bs, max_cache_length, model.device, model.dtype)
for i in range(max_new_tokens):
    out = model(next_token, past_key_values=past_key_values, return_dict=True, **model_kwargs)
    past_key_values = out.past_key_values  # the same cache object, updated in place
    next_token = sample_next_token(out)               # placeholder: your sampling logic
    model_kwargs = update_model_kwargs(model_kwargs)  # placeholder: e.g. update the attention mask

Alternatively, you can simply pass cache_implementation="static" to generate() (see the docs):

model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
model.generation_config.max_new_tokens = 20
out = model.generate(input_ids, cache_implementation="static")

@mobicham
Contributor Author
mobicham commented Jun 4, 2024

StaticCache was moved to be a standalone object in #30476. You have to initialize StaticCache outside the model and pass it in every forward call, similar to the following:

Thanks, will try that!

model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
model.generation_config.max_new_tokens = 20
out = model.generate(input_ids, cache_implementation="static")

That's not the same thing. This will compile the whole forward pass, which will force re-compilation every time the prompt length changes. I only want to compile the decoding part and leave the prefill phase uncompiled.
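
For reference, here is a minimal sketch of the pattern I mean, assuming the standalone StaticCache API from #30476: the prefill runs eagerly on the variable-length prompt, and only a fixed-shape one-token decode step is compiled (the checkpoint name and the greedy argmax are just placeholders):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, StaticCache

model_id = "meta-llama/Llama-2-7b-hf"  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="cuda")
tokenizer = AutoTokenizer.from_pretrained(model_id)

max_cache_len = 1024
past_key_values = StaticCache(model.config, 1, max_cache_len, model.device, model.dtype)

# Compile only the one-token decode step; its input shapes never change.
@torch.compile(mode="reduce-overhead", fullgraph=True)
def decode_one_token(token, cache_position):
    return model(token, past_key_values=past_key_values, cache_position=cache_position).logits

input_ids = tokenizer("Hello", return_tensors="pt").input_ids.to(model.device)
seq_len = input_ids.shape[1]

# Prefill: eager, so a different prompt length does not trigger recompilation.
cache_position = torch.arange(seq_len, device=model.device)
logits = model(input_ids, past_key_values=past_key_values, cache_position=cache_position).logits
next_token = logits[:, -1].argmax(dim=-1, keepdim=True)  # placeholder: greedy decoding

# Decode: always a single token in, so the compiled graph is reused.
for i in range(20):
    cache_position = torch.tensor([seq_len + i], device=model.device)
    logits = decode_one_token(next_token, cache_position)
    next_token = logits[:, -1].argmax(dim=-1, keepdim=True)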


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot closed this as completed Jul 8, 2024