Different results from multiple lora weight merge method #1904
Comments
In your first snippet, you merge the first adapter's weights into the base weights, then load the second adapter. The second method does not merge the weights at all; it only activates the two adapters. In both cases, the end result should be that the two adapters are active. Due to floating point precision, merging will always lead to slightly different results than not merging. Did you find big differences or only small ones? This is best detected by inspecting the logits.
As you said, there is a small difference before and after merge_and_unload() due to floating point precision, but there is a large difference in the logits between the two methods above. Can you explain the difference between the two merge methods? (A document worth referencing would also be helpful.) Unfortunately, the document below does not provide a detailed explanation. https://huggingface.co/docs/peft/developer_guides/mixed_models
As I mentioned earlier, the second method does not involve any merging at all. It is possible to apply LoRA without merging by calculating the delta activations from LoRA and adding them to the activations of the base model. Mathematically, the two operations are the same, but in practice they are not, due to floating point imprecision. Also, without merging, more compute is required, so it should be a bit slower.
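The point about floating point imprecision can be made concrete without PEFT at all: merging changes the order in which floating point operations are applied, and floating point addition is not associative. A minimal plain-Python sketch (the values are arbitrary stand-ins, not actual model weights):

```python
# Floating point addition is not associative: regrouping the same
# mathematically-equivalent sum changes the last bits of the result.
a, b, c = 0.1, 0.2, 0.3

grouped_left = (a + b) + c   # analogous to merging first, then applying
grouped_right = a + (b + c)  # analogous to applying deltas on the fly

print(grouped_left)                        # 0.6000000000000001
print(grouped_right)                       # 0.6
print(grouped_left == grouped_right)       # False
print(abs(grouped_left - grouped_right))   # ~1.1e-16
```

The same effect, accumulated over millions of parameters and many layers, is why merged and unmerged LoRA produce logits that agree only up to a small tolerance.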
I tried your two methods with a local adapter and found that the results are pretty much identical:

import torch
from peft import PeftModel, PeftMixedModel, get_peft_model, LoraConfig
from transformers import AutoModelForCausalLM
model_id = "facebook/opt-125m"
model = AutoModelForCausalLM.from_pretrained(model_id)
inputs = torch.arange(10).view(-1, 1)
logits_base = model(inputs).logits[0, 0]
print("base")
print(logits_base)
# tensor([-3.9534, -3.9515, 3.2369, ..., -3.9653, -3.9600, -4.0763],
# grad_fn=<SelectBackward0>)
# create adapters
lora_config = LoraConfig(init_lora_weights=False) # <= avoid zero init
torch.manual_seed(0)
model = get_peft_model(model, lora_config)
model.save_pretrained("/tmp/peft/1904/lora1")
del model
model = AutoModelForCausalLM.from_pretrained(model_id)
lora_config = LoraConfig(init_lora_weights=False) # <= avoid zero init
torch.manual_seed(1)
model = get_peft_model(model, lora_config)
model.save_pretrained("/tmp/peft/1904/lora2")
model.eval()
del model
# method1
model = AutoModelForCausalLM.from_pretrained(model_id)
model = PeftModel.from_pretrained(model, "/tmp/peft/1904/lora1")
model = model.merge_and_unload()
model = PeftModel.from_pretrained(model, "/tmp/peft/1904/lora2")
model.eval()
logits_merged = model(inputs).logits[0, 0]
print("merged")
print(logits_merged)
# tensor([-1.9932, -2.0046, 1.5403, ..., -2.0966, -2.0581, -2.0707])
del model
# method2
model = AutoModelForCausalLM.from_pretrained(model_id)
model = PeftMixedModel.from_pretrained(model, "/tmp/peft/1904/lora1")
model.load_adapter("/tmp/peft/1904/lora2", adapter_name="other")
model.set_adapter(["default", "other"])
model.eval()
logits_multi = model(inputs).logits[0, 0]
print("multi")
print(logits_multi)
# tensor([-1.9932, -2.0046, 1.5403, ..., -2.0966, -2.0581, -2.0707],
# grad_fn=<SelectBackward0>)

The results are pretty much identical. I checked the tolerance: 1e-5 passes but 1e-6 fails due to floating point differences. Could you check if you can reproduce these results?
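The tolerance check mentioned above can be made explicit with torch.allclose. A minimal sketch — the tensors here are hypothetical stand-ins for the logits_merged and logits_multi computed in the snippet above, with a simulated deviation of about 1e-6:

```python
import torch

# Hypothetical stand-ins for the logits computed by the two methods;
# in the real run they come from model(inputs).logits[0, 0].
logits_merged = torch.tensor([-1.9932, -2.0046, 1.5403])
logits_multi = logits_merged + 1e-6  # simulate a ~1e-6 float deviation

# rtol=0 so only the absolute tolerance matters
print(torch.allclose(logits_merged, logits_multi, rtol=0, atol=1e-5))  # True
print(torch.allclose(logits_merged, logits_multi, rtol=0, atol=1e-7))  # False
```

Note that torch.allclose defaults to rtol=1e-5, which for logits of magnitude ~2 already allows deviations around 2e-5; setting rtol=0 makes the atol threshold the one actually being tested.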
I'm sorry, my experiment turned out to be wrong. I also ran the above code and confirmed the same result. Thank you for telling me.
System Info
What is the difference between the two merge methods below? They produce different results.
multiple lora weight merge method1
multiple lora weight merge method2
Below is a list of the Python and library versions.
Python : 3.10.13
Package Version
accelerate 0.31.0
aiohttp 3.9.5
aiosignal 1.3.1
annotated-types 0.7.0
asttokens 2.4.1
async-timeout 4.0.3
attrs 23.2.0
beir 2.0.0
certifi 2024.6.2
charset-normalizer 3.3.2
comm 0.2.2
contourpy 1.2.1
cycler 0.12.1
datasets 2.20.0
debugpy 1.8.1
decorator 5.1.1
dill 0.3.8
docstring_parser 0.16
einops 0.8.0
elasticsearch 7.9.1
evaluate 0.4.2
exceptiongroup 1.2.0
executing 2.0.1
faiss-cpu 1.8.0.post1
filelock 3.15.4
fire 0.6.0
flash-attn 2.5.9.post1
fonttools 4.53.0
frozenlist 1.4.1
fsspec 2024.5.0
huggingface-hub 0.23.4
idna 3.7
importlib_metadata 7.2.1
ipykernel 6.29.4
ipython 8.25.0
jedi 0.19.1
Jinja2 3.1.4
joblib 1.4.2
jsonschema 4.21.1
jsonschema-specifications 2023.12.1
jupyter_client 8.6.2
jupyter_core 5.7.2
kiwisolver 1.4.5
llm2vec 0.1.7
MarkupSafe 2.1.5
matplotlib 3.9.0
matplotlib-inline 0.1.7
mistral_common 1.2.1
mistral_inference 1.1.0
mpmath 1.3.0
multidict 6.0.5
multiprocess 0.70.16
nest_asyncio 1.6.0
networkx 3.3
numpy 2.0.0
nvidia-cublas-cu12 12.1.3.1
nvidia-cuda-cupti-cu12 12.1.105
nvidia-cuda-nvrtc-cu12 12.1.105
nvidia-cuda-runtime-cu12 12.1.105
nvidia-cudnn-cu12 8.9.2.26
nvidia-cufft-cu12 11.0.2.54
nvidia-curand-cu12 10.3.2.106
nvidia-cusolver-cu12 11.4.5.107
nvidia-cusparse-cu12 12.1.0.106
nvidia-nccl-cu12 2.20.5
nvidia-nvjitlink-cu12 12.5.40
nvidia-nvtx-cu12 12.1.105
packaging 24.1
pandas 2.2.2
parso 0.8.4
peft 0.11.1
pexpect 4.9.0
pickleshare 0.7.5
pillow 10.3.0
pip 24.0
platformdirs 4.2.2
prompt_toolkit 3.0.47
protobuf 5.27.1
psutil 6.0.0
ptyprocess 0.7.0
pure-eval 0.2.2
pyarrow 16.1.0
pyarrow-hotfix 0.6
pydantic 2.6.1
pydantic_core 2.16.2
Pygments 2.18.0
pyparsing 3.1.2
python-dateutil 2.9.0
pytrec_eval 0.5
pytz 2024.1
PyYAML 6.0.1
pyzmq 26.0.3
referencing 0.35.1
regex 2024.5.15
requests 2.32.3
rpds-py 0.18.1
safetensors 0.4.3
scikit-learn 1.5.0
scipy 1.13.1
sentence-transformers 3.0.1
sentencepiece 0.1.99
setuptools 70.1.0
simple_parsing 0.1.5
six 1.16.0
stack-data 0.6.2
sympy 1.12.1
termcolor 2.4.0
threadpoolctl 3.5.0
tokenizers 0.19.1
torch 2.3.0
tornado 6.4.1
tqdm 4.66.4
traitlets 5.14.3
transformers 4.41.2
triton 2.3.0
typing_extensions 4.12.2
tzdata 2024.1
urllib3 2.2.2
wcwidth 0.2.13
wheel 0.43.0
xformers 0.0.26.post1
xxhash 3.4.1
yarl 1.9.4
zipp 3.19.2
Who can help?
No response
Information
Tasks
examples folder

Reproduction
This is explained in the question above.
Expected behavior
This is explained in the question above.