partial implementation of lqlora #8324

Open · wants to merge 2 commits into develop

Conversation

@Liebele commented Apr 24, 2024

PR types

PR changes

Description

@CLAassistant

CLA assistant check: Thank you for your submission! Like many open source projects, we ask that you sign our Contributor License Agreement before we can accept your contribution.

paddle-bot commented Apr 24, 2024

Thanks for your contribution!

codecov bot commented Apr 24, 2024

Codecov Report

Attention: Patch coverage is 0% with 41 lines in your changes missing coverage. Please review.

Project coverage is 54.39%. Comparing base (0844a5b) to head (88b3455).

Files                                 Patch %   Lines
paddlenlp/peft/lora/lqlora_utils.py   0.00%     41 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #8324      +/-   ##
===========================================
- Coverage    54.41%   54.39%   -0.03%     
===========================================
  Files          632      633       +1     
  Lines        99475    99516      +41     
===========================================
  Hits         54127    54127              
- Misses       45348    45389      +41     


@ZHUI requested a review from lugimzzz April 25, 2024 06:12
@Liebele (Author) commented Apr 26, 2024

The algorithm is implemented as follows:
[figure: LQ-LoRA decomposition scheme, image omitted]
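Assembled from the hunks quoted in the review threads below, a minimal end-to-end sketch of the initialization (the wrapper name lqlora_decompose and the cast-back step are illustrative assumptions, not the PR's exact code):

import paddle
from paddlenlp.quantization.qlora import qlora_weight_quantize_dequantize

def lqlora_decompose(W, num_ranks):
    # SVD of fp16 weights is typically unsupported, so work in fp32.
    old_dtype = W.dtype
    if old_dtype == paddle.float16:
        W = paddle.cast(W, dtype=paddle.float32)
    U, S, Vh = paddle.linalg.svd(W, full_matrices=False)
    Ur, Sr, Vhr = U[:, :num_ranks], S[:num_ranks], Vh[:num_ranks]
    # Split sqrt(S) evenly between the two low-rank factors.
    sqrt_S = paddle.diag(paddle.sqrt(Sr))
    lora_A = Ur @ sqrt_S
    lora_B = sqrt_S @ Vhr
    # Quantize-dequantize the residual so that Q + lora_A @ lora_B ≈ W.
    Q = qlora_weight_quantize_dequantize(W - lora_A @ lora_B, double_quant=True)
    return tuple(paddle.cast(t, old_dtype) for t in (Q, lora_A, lora_B))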

lora_A = Ur @ paddle.diag(paddle.sqrt(Sr))
lora_B = paddle.diag(paddle.sqrt(Sr)) @ Vhr

Q = qlora_weight_quantize_dequantize(W - lora_A @ lora_B, double_quant=True)
Contributor:

double_quant=True should be an adjustable parameter, and the same applies to the other arguments of qlora_weight_quantize_dequantize.
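One way to follow this suggestion is to thread the quantization options through instead of hard-coding them (lqlora_quantize is a hypothetical helper name):

def lqlora_quantize(residual, **quant_kwargs):
    # Forward the caller's options (double_quant, etc.) rather than fixing them here.
    return qlora_weight_quantize_dequantize(residual, **quant_kwargs)

Q = lqlora_quantize(W - lora_A @ lora_B, double_quant=True)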

Sr = S[:num_ranks]
Vhr = Vh[:num_ranks]

lora_A = Ur @ paddle.diag(paddle.sqrt(Sr))
Contributor:

The configuration needs to account for LoRA scaling; as written, it looks like the LoRA scaling factor can only be forced to 1.
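For context: a standard LoRA forward adds scaling * (x @ lora_A @ lora_B) with scaling = lora_alpha / r, while the SVD factors above already reconstruct W - Q exactly, so this init is only consistent when scaling == 1. A hypothetical fix folds the factor into one side:

scaling = lora_config.lora_alpha / lora_config.r
# Pre-divide one factor so that scaling * (lora_A @ lora_B) still equals Ur @ diag(Sr) @ Vhr.
lora_A = (Ur @ paddle.diag(paddle.sqrt(Sr))) / scaling
lora_B = paddle.diag(paddle.sqrt(Sr)) @ Vhr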


if W.dtype in [paddle.float16]:
    old_dtype = W.dtype
    W = paddle.cast(W, dtype=paddle.float32)
Contributor:

What is the reason for casting to fp32?
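A plausible answer, not confirmed by the author in this thread: paddle.linalg.svd supports float32/float64 inputs but not float16, so the weight is promoted for the decomposition and the factors are cast back afterwards, e.g.:

U, S, Vh = paddle.linalg.svd(paddle.cast(W, paddle.float32), full_matrices=False)
lora_A = paddle.cast(Ur @ paddle.diag(paddle.sqrt(Sr)), old_dtype)  # restore original dtype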

@lugimzzz (Contributor):

Are there any experimental results we can look at to assess the effect?

@lugimzzz (Contributor) commented Apr 29, 2024

Before submitting, fix the formatting issues:

cd PaddleNLP
pre-commit install

@Liebele (Author) commented May 6, 2024

Fine-tuning results on the E2E dataset:
[figure: fine-tuning results, image omitted]


import paddle
from paddlenlp.quantization.qlora import qlora_weight_quantize_dequantize

Contributor:

Suggestion: wrap the LQ-LoRA initialization in an lqlora_init function, pass whether to use lqlora via lora_config, and consider applying this lqlora_init to lora_module before line 621: https://github.com/PaddlePaddle/PaddleNLP/blob/develop/paddlenlp/peft/lora/lora_model.py#L621
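A hypothetical sketch of that refactor (the import path for LoRALinear and the lqlora flag on LoRAConfig are assumptions; lqlora_decompose is the sketch from earlier in the thread):

from paddlenlp.peft.lora.lora_layers import LoRALinear

def lqlora_init(module):
    # Re-initialize each LoRA layer with the quantized-residual decomposition.
    if isinstance(module, LoRALinear):
        Q, lora_A, lora_B = lqlora_decompose(module.weight, module.r)
        module.weight.set_value(Q)
        module.lora_A.set_value(lora_A)
        module.lora_B.set_value(lora_B)

# Inside LoRAModel construction, gated by the (assumed) config flag:
if lora_config.lqlora:
    lora_module.apply(lqlora_init)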

@@ -477,6 +478,9 @@ def neft_post_hook(module, input, output):
     else:
         model = LoRAModel.from_pretrained(model=model, lora_path=model_args.lora_path)
 
+    if model_args.lqlora:
+        transform_lora_layers(model)
Contributor:

Control this via an lqlora option passed in through lora_config instead.
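A sketch of the suggested configuration path (lqlora is an assumed new LoRAConfig field, not an existing parameter):

from paddlenlp.peft import LoRAConfig, LoRAModel

lora_config = LoRAConfig(
    target_modules=[".*q_proj.*", ".*v_proj.*"],
    r=8,
    lora_alpha=8,
    lqlora=True,  # assumed new field replacing the model_args.lqlora switch
)
model = LoRAModel(model, lora_config)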
