Different results from multiple lora weight merge method #1904

Closed
daegonYu opened this issue Jul 3, 2024 · 4 comments

Comments

daegonYu commented Jul 3, 2024

System Info

What is the difference between the two merge methods below? They produce different results.

Multiple LoRA weight merge, method 1:

from peft import PeftMixedModel
from peft import PeftModel
from transformers import AutoModelForCausalLM

base_model = AutoModelForCausalLM.from_pretrained("tiiuae/falcon-40b")

# load the first adapter and merge it into the base weights
peft_model = PeftModel.from_pretrained(base_model, <adapter_path1>)
peft_model = peft_model.merge_and_unload()

# load the second adapter on top of the already merged model
peft_model = PeftModel.from_pretrained(peft_model, <adapter_path2>)

Multiple LoRA weight merge, method 2:

base_model = AutoModelForCausalLM.from_pretrained("tiiuae/falcon-40b")
# load both adapters without merging and activate them together
peft_model = PeftMixedModel.from_pretrained(base_model, <adapter_path1>)
peft_model.load_adapter(<adapter_path2>, adapter_name="other")
peft_model.set_adapter(["default", "other"])

Below is a list of Python and library versions.

Python: 3.10.13

Package Version


accelerate 0.31.0
aiohttp 3.9.5
aiosignal 1.3.1
annotated-types 0.7.0
asttokens 2.4.1
async-timeout 4.0.3
attrs 23.2.0
beir 2.0.0
certifi 2024.6.2
charset-normalizer 3.3.2
comm 0.2.2
contourpy 1.2.1
cycler 0.12.1
datasets 2.20.0
debugpy 1.8.1
decorator 5.1.1
dill 0.3.8
docstring_parser 0.16
einops 0.8.0
elasticsearch 7.9.1
evaluate 0.4.2
exceptiongroup 1.2.0
executing 2.0.1
faiss-cpu 1.8.0.post1
filelock 3.15.4
fire 0.6.0
flash-attn 2.5.9.post1
fonttools 4.53.0
frozenlist 1.4.1
fsspec 2024.5.0
huggingface-hub 0.23.4
idna 3.7
importlib_metadata 7.2.1
ipykernel 6.29.4
ipython 8.25.0
jedi 0.19.1
Jinja2 3.1.4
joblib 1.4.2
jsonschema 4.21.1
jsonschema-specifications 2023.12.1
jupyter_client 8.6.2
jupyter_core 5.7.2
kiwisolver 1.4.5
llm2vec 0.1.7
MarkupSafe 2.1.5
matplotlib 3.9.0
matplotlib-inline 0.1.7
mistral_common 1.2.1
mistral_inference 1.1.0
mpmath 1.3.0
multidict 6.0.5
multiprocess 0.70.16
nest_asyncio 1.6.0
networkx 3.3
numpy 2.0.0
nvidia-cublas-cu12 12.1.3.1
nvidia-cuda-cupti-cu12 12.1.105
nvidia-cuda-nvrtc-cu12 12.1.105
nvidia-cuda-runtime-cu12 12.1.105
nvidia-cudnn-cu12 8.9.2.26
nvidia-cufft-cu12 11.0.2.54
nvidia-curand-cu12 10.3.2.106
nvidia-cusolver-cu12 11.4.5.107
nvidia-cusparse-cu12 12.1.0.106
nvidia-nccl-cu12 2.20.5
nvidia-nvjitlink-cu12 12.5.40
nvidia-nvtx-cu12 12.1.105
packaging 24.1
pandas 2.2.2
parso 0.8.4
peft 0.11.1
pexpect 4.9.0
pickleshare 0.7.5
pillow 10.3.0
pip 24.0
platformdirs 4.2.2
prompt_toolkit 3.0.47
protobuf 5.27.1
psutil 6.0.0
ptyprocess 0.7.0
pure-eval 0.2.2
pyarrow 16.1.0
pyarrow-hotfix 0.6
pydantic 2.6.1
pydantic_core 2.16.2
Pygments 2.18.0
pyparsing 3.1.2
python-dateutil 2.9.0
pytrec_eval 0.5
pytz 2024.1
PyYAML 6.0.1
pyzmq 26.0.3
referencing 0.35.1
regex 2024.5.15
requests 2.32.3
rpds-py 0.18.1
safetensors 0.4.3
scikit-learn 1.5.0
scipy 1.13.1
sentence-transformers 3.0.1
sentencepiece 0.1.99
setuptools 70.1.0
simple_parsing 0.1.5
six 1.16.0
stack-data 0.6.2
sympy 1.12.1
termcolor 2.4.0
threadpoolctl 3.5.0
tokenizers 0.19.1
torch 2.3.0
tornado 6.4.1
tqdm 4.66.4
traitlets 5.14.3
transformers 4.41.2
triton 2.3.0
typing_extensions 4.12.2
tzdata 2024.1
urllib3 2.2.2
wcwidth 0.2.13
wheel 0.43.0
xformers 0.0.26.post1
xxhash 3.4.1
yarl 1.9.4
zipp 3.19.2

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder
  • My own task or dataset (give details below)

Reproduction

This is explained in the question above.

Expected behavior

This is explained in the question above.

BenjaminBossan (Member) commented:

In your first snippet, you merge the first adapter weights into the base weights, then load the second adapter. The second method does not merge the weights at all, it only activates the two adapters. In both cases, the end result should be that the two adapters are active.

Due to floating point precision, merging will always lead to slightly different results than not merging. Did you find big differences or only small ones? This is best checked by inspecting the logits.
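
A minimal sketch of such a logits comparison, assuming two hypothetical models peft_model_merged and peft_model_mixed built with the two methods above, and an arbitrary dummy input:

import torch

# peft_model_merged / peft_model_mixed are placeholders for models built with
# method 1 and method 2 from the question; any tokenized prompt works as input
inputs = torch.arange(10).view(-1, 1)

with torch.no_grad():
    logits_merged = peft_model_merged(inputs).logits
    logits_mixed = peft_model_mixed(inputs).logits

# large deviations point to a setup problem; tiny ones are just floating point noise
print((logits_merged - logits_mixed).abs().max())
print(torch.allclose(logits_merged, logits_mixed, atol=1e-5))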

daegonYu (Author) commented Jul 4, 2024

As you said, there is a small difference before and after merge_and_unload() due to floating point precision, but there is a large difference in the logits between the two methods above. Can you explain the difference between the two merge methods? (A document worth referencing would also be helpful.) Unfortunately, the document below does not provide a detailed explanation: https://huggingface.co/docs/peft/developer_guides/mixed_models

BenjaminBossan (Member) commented:

Can you explain the difference between the two merge methods?

As I mentioned earlier, the second method does not involve any merging at all. It is possible to apply LoRA without merging, by calculating the delta activations from the LoRA weights and adding them to the activations of the base model. Mathematically, the two operations are the same, but in practice they are not, due to floating point imprecision. Also, without merging, more compute is required, so it should be a bit slower.
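
To illustrate why the two paths are mathematically equivalent but not bit-identical, here is a small standalone sketch with made-up weights (W, A, B and the scaling factor are arbitrary, not taken from a real model):

import torch

torch.manual_seed(0)
d_in, d_out, r, scaling = 512, 512, 8, 2.0
x = torch.randn(4, d_in)
W = torch.randn(d_out, d_in)   # base weight
A = torch.randn(r, d_in)       # lora_A
B = torch.randn(d_out, r)      # lora_B

# merged: fold the LoRA delta into the base weight, then do a single matmul
y_merged = x @ (W + scaling * B @ A).T

# unmerged: base activations plus the LoRA delta activations
y_unmerged = x @ W.T + scaling * (x @ A.T @ B.T)

# same math, but the different order of operations leaves a tiny floating point gap
print((y_merged - y_unmerged).abs().max())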

there is a large difference in the logits between the two methods above

I tried your two methods with a local adapter and found that the results are pretty much identical:

import torch
from peft import PeftModel, PeftMixedModel, get_peft_model, LoraConfig
from transformers import AutoModelForCausalLM

model_id = "facebook/opt-125m"
model = AutoModelForCausalLM.from_pretrained(model_id)
inputs = torch.arange(10).view(-1, 1)
logits_base = model(inputs).logits[0, 0]
print("base")
print(logits_base)
# tensor([-3.9534, -3.9515,  3.2369,  ..., -3.9653, -3.9600, -4.0763],
#       grad_fn=<SelectBackward0>)

# create adapters
lora_config = LoraConfig(init_lora_weights=False)  # <= avoid zero init
torch.manual_seed(0)
model = get_peft_model(model, lora_config)
model.save_pretrained("/tmp/peft/1904/lora1")
del model

model = AutoModelForCausalLM.from_pretrained(model_id)
lora_config = LoraConfig(init_lora_weights=False)  # <= avoid zero init
torch.manual_seed(1)
model = get_peft_model(model, lora_config)
model.save_pretrained("/tmp/peft/1904/lora2")
model.eval()
del model

# method1
model = AutoModelForCausalLM.from_pretrained(model_id)
model = PeftModel.from_pretrained(model, "/tmp/peft/1904/lora1")
model = model.merge_and_unload()
model = PeftModel.from_pretrained(model, "/tmp/peft/1904/lora2")
model.eval()
logits_merged = model(inputs).logits[0, 0]
print("merged")
print(logits_merged)
# tensor([-1.9932, -2.0046,  1.5403,  ..., -2.0966, -2.0581, -2.0707])
del model

# method2
model = AutoModelForCausalLM.from_pretrained(model_id)
model = PeftMixedModel.from_pretrained(model, "/tmp/peft/1904/lora1")
model.load_adapter("/tmp/peft/1904/lora2", adapter_name="other")
model.set_adapter(["default", "other"])
model.eval()
logits_multi = model(inputs).logits[0, 0]
print("multi")
print(logits_multi)
# tensor([-1.9932, -2.0046,  1.5403,  ..., -2.0966, -2.0581, -2.0707],
#        grad_fn=<SelectBackward0>)

The results are pretty much identical. I checked the tolerance: 1e-5 passes but 1e-6 fails due to floating point differences. Could you check whether you can reproduce these results?
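
A quick way to run that tolerance check, reusing the logits_merged and logits_multi tensors from the snippet above:

# a loose tolerance passes, a very tight one fails, due to floating point error
print(torch.allclose(logits_merged, logits_multi, atol=1e-5))
print(torch.allclose(logits_merged, logits_multi, atol=1e-6))
print((logits_merged - logits_multi).abs().max())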

daegonYu (Author) commented Jul 5, 2024

I'm sorry, my experiment turned out to be wrong. I also ran the above code and confirmed the same result. Thank you for telling me.

daegonYu closed this as completed on Jul 5, 2024.