
FEAT / Bitsandbytes: Add dequantize API for bitsandbytes quantized models #30806

Merged: 11 commits merged into huggingface:main on May 15, 2024

Conversation

@younesbelkada (Contributor) commented May 14, 2024

What does this PR do?

Fixes #30177

This PR adds a new dequantize feature in order to de-quantize models for interesting use cases such as the one described in #30177.

The API is very simple:

from transformers import AutoModelForCausalLM, BitsAndBytesConfig, AutoTokenizer

model_id = "facebook/opt-125m"

model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=BitsAndBytesConfig(load_in_4bit=True))
tokenizer = AutoTokenizer.from_pretrained(model_id)

model.dequantize()

text = tokenizer("Hello my name is", return_tensors="pt").to(0)

out = model.generate(**text)
print(tokenizer.decode(out[0]))

Users just need to make sure they have enough GPU memory to store the de-quantized model, otherwise they might face unexpected behaviour.

Added support for both 4-bit and 8-bit models, along with tests and docs to show users how to use this new API.
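As a rough sketch, the 8-bit path follows the same pattern (assuming the same model id as the example above; saving the result afterwards with save_pretrained is optional, and the output directory name is just an example):

from transformers import AutoModelForCausalLM, BitsAndBytesConfig

model_id = "facebook/opt-125m"

# Load the model quantized to 8-bit with bitsandbytes
model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=BitsAndBytesConfig(load_in_8bit=True))

# De-quantize the weights back to a dense dtype in place
model.dequantize()

# Optionally persist the de-quantized checkpoint ("opt-125m-dequantized" is an example path)
model.save_pretrained("opt-125m-dequantized")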

cc @amyeroberts @SunMarc

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@SunMarc (Member) left a comment

Thanks for adding this new method to the quantizer! This will make fine-tuning with quantized models way easier! I left a few minor comments.

Resolved review threads on:

  • docs/source/en/quantization.md
  • src/transformers/integrations/__init__.py
  • src/transformers/modeling_utils.py
  • src/transformers/quantizers/quantizer_bnb_4bit.py
  • src/transformers/quantizers/quantizer_bnb_8bit.py
Comment on lines 346 to 347:

if cls_name == "Params4bit":
    return bnb.functional.dequantize_4bit(weight.data, weight.quant_state)
@SunMarc (Member) commented May 15, 2024

The user might want to know which precision the model was dequantized to, since they have no way to control that. I think it would be great to give that information, since there is no default value (as opposed to from_pretrained, which loads the model in fp32).
Two ways to get it:

  • just check the dtype of the weights at the end (potentially the easiest way; see the sketch below)
  • check what happens in dequantize_4bit: in that method, the output dtype comes from weight.quant_state.dtype

We can potentially add a torch_dtype attribute in the future if it makes sense.
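A minimal sketch of the first option (checking the dtype at the end), assuming the dequantize API added in this PR and the opt-125m example from the description:

from transformers import AutoModelForCausalLM, BitsAndBytesConfig

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m", quantization_config=BitsAndBytesConfig(load_in_4bit=True))
model.dequantize()

# The precision the model was de-quantized to can be read off any weight afterwards
print(next(model.parameters()).dtype)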

(Collaborator)

+1

@younesbelkada (Contributor, Author)

Nice catch! The output dtype should be correctly inferred here: https://github.com/TimDettmers/bitsandbytes/blob/b891f80ba514833f41f0e9226983b02a9fb5c44b/bitsandbytes/functional.py#L1349 through the compute_dtype, so it should be accurate. I added a warning_once statement to inform users of the de-quantized dtype: 1a4a906
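For reference, a hedged sketch of the second option: before dequantize() is called, the target dtype can be read from the 4-bit parameters' quant state (the attribute names here follow the bitsandbytes version linked above and are an assumption, not part of this PR):

import bitsandbytes as bnb

# Assumes `model` was loaded with load_in_4bit=True and dequantize() has not been called yet
for name, param in model.named_parameters():
    if isinstance(param, bnb.nn.Params4bit):
        # quant_state.dtype is the dtype the weight will be de-quantized back to
        print(name, param.quant_state.dtype)
        break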

@amyeroberts (Collaborator) left a comment

Thanks for adding this! +1 on all of @SunMarc's comments.

Resolved review thread: tests/quantization/bnb/test_mixed_int8.py
Comment on lines 346 to 347:

if cls_name == "Params4bit":
    return bnb.functional.dequantize_4bit(weight.data, weight.quant_state)
(Collaborator)

+1


Returns the converted model and a boolean that indicates if the conversion has been successful or not.
"""
import bitsandbytes as bnb
(Collaborator)

This is already imported at the top of the module

@younesbelkada (Contributor, Author)

Nice catch! Should be fixed now.

younesbelkada and others added 5 commits May 15, 2024 13:02
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
@amyeroberts (Collaborator) left a comment

Thanks for adding this feature and iterating!

)
# Remove the last key for recursion
current_key_name.pop(-1)
return model, has_been_replaced
@amyeroberts (Collaborator) commented May 15, 2024

One general comment: if you instead have a private method _dequantize_and_replace that handles the recursion, you don't need to return has_been_replaced here. When someone calls dequantize_and_replace, I don't think has_been_replaced is ever used, and it could be confusing, e.g.:

# This is just dequantize_and_replace from before
def _dequantize_and_replace(
    model,
    modules_to_not_convert=None,
    current_key_name=None,
    quantization_config=None,
    has_been_replaced=False,
):
    ...
    return model, has_been_replaced

def dequantize_and_replace(
    model,
    modules_to_not_convert=None,
    current_key_name=None,
    quantization_config=None,
    has_been_replaced=False,
):
    model, has_been_replaced = _dequantize_and_replace(...)
    return model 

@younesbelkada (Contributor, Author)

Makes sense! Will do!

@younesbelkada (Contributor, Author)

Done in 8b904f7!
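For readers following along, here is a hedged sketch of what that public/private split could look like once combined with the warning shown in the later diff; the actual code in 8b904f7 may differ in names and details:

import logging

logger = logging.getLogger(__name__)  # the real module uses its own logging utilities

# Private helper: keeps the recursion and the has_been_replaced flag for internal use
def _dequantize_and_replace(model, modules_to_not_convert=None, current_key_name=None, quantization_config=None, has_been_replaced=False):
    # ... recursive module replacement, as in the reviewer's sketch above ...
    return model, has_been_replaced

# Public entry point: hides the flag and only warns when nothing was replaced
def dequantize_and_replace(model, modules_to_not_convert=None, quantization_config=None):
    model, has_been_replaced = _dequantize_and_replace(
        model,
        modules_to_not_convert=modules_to_not_convert,
        quantization_config=quantization_config,
    )
    if not has_been_replaced:
        logger.warning("...")  # placeholder; not the library's actual message
    return model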

Resolved review thread: src/transformers/integrations/bitsandbytes.py
younesbelkada and others added 2 commits May 15, 2024 16:22
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
)

if not has_been_replaced:
    logger.warning(
(Collaborator)

Nice :)

@younesbelkada merged commit 3f43582 into huggingface:main on May 15, 2024 (22 checks passed)
@younesbelkada deleted the add-dequant branch on May 15, 2024 at 15:17
@RonanKMcGovern

Yeah this is great, thanks

@younesbelkada (Contributor, Author)

Great, thanks @RonanKMcGovern! Let us know how it goes.

itazap pushed a commit that referenced this pull request May 24, 2024
FEAT / Bitsandbytes: Add dequantize API for bitsandbytes quantized models (#30806)

* add  method

* change method name

* more comments

* Apply suggestions from code review

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* fixup

* add docstrings and fix comment

* warn users on the de-quantized dtype

* Update src/transformers/quantizers/base.py

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Update src/transformers/integrations/bitsandbytes.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* final suggestion - use private method

---------

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
zucchini-nlp pushed a commit to zucchini-nlp/transformers that referenced this pull request Jun 11, 2024
FEAT / Bitsandbytes: Add dequantize API for bitsandbytes quantized models (huggingface#30806)


Successfully merging this pull request may close these issues:

  • Load nf4 weights/model in bfloat16 (#30177)
5 participants