Do I need to save the model file in bin format during the training process? #61

Closed
ScottishFold007 opened this issue Feb 9, 2023 · 2 comments


@ScottishFold007
When training a model with PEFT, do I also need to save the full model file in .bin format? I saved it and loaded it together with the 'lora.pt' file, but the model's generations were poor and made little sense.
This is my inference code:

```python
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from peft import get_peft_config, get_peft_model, LoraConfig, TaskType, peft_model_load_and_dispatch


model_name_or_path = "/root/gaochangkuan_AI/PromptCLUE_Finetuning/model_finetuning_1_epoch"
checkpoint_name="/root/gaochangkuan_AI/PromptCLUE_Finetuning/model_finetuning_1_epoch/promptclue_lora_fsdp_v1.pt"
max_memory={0: "1GIB", 1: "1GIB", 2: "2GIB", 3: "10GIB", "cpu":"30GB"}
peft_config = LoraConfig(
    task_type=TaskType.SEQ_2_SEQ_LM, inference_mode=True, r=8, lora_alpha=32, lora_dropout=0.1
)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path,
                                #device_map="auto",
                                max_memory=max_memory
                                )
#model = get_peft_model(model, peft_config)
device = torch.device('cuda:7') # cuda
model.to(device)
peft_model_load_and_dispatch(model, torch.load(checkpoint_name), peft_config, max_memory)
```
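For reference, here is a minimal sketch of what `r=8` and `lora_alpha=32` in the `LoraConfig` above control, assuming the standard LoRA formulation (this is not PEFT's own code): each adapted layer computes the frozen projection `W x` plus a low-rank update `B (A x)` scaled by `lora_alpha / r`, which is 4 for this config. The helper names `matvec`, `lora_forward` and `merge` are hypothetical, and the toy example uses `r=2` to keep the matrices small.

```python
# Minimal LoRA sketch in plain Python (assumption: standard LoRA formulation,
# not PEFT's implementation). Matrices are lists of rows.

def matvec(m, v):
    """Multiply matrix m (list of rows) by vector v."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def lora_forward(W, A, B, x, r, lora_alpha):
    """Frozen base projection W x plus the scaled low-rank update B (A x)."""
    scaling = lora_alpha / r          # 32 / 8 = 4.0 for the config in the question
    base = matvec(W, x)               # W: d_out x d_in (frozen during training)
    update = matvec(B, matvec(A, x))  # A: r x d_in, B: d_out x r (trained)
    return [b + scaling * u for b, u in zip(base, update)]


def merge(W, A, B, r, lora_alpha):
    """Fold the adapter into the base weight: W' = W + (lora_alpha / r) * B A."""
    scaling = lora_alpha / r
    return [
        [W[i][j] + scaling * sum(B[i][k] * A[k][j] for k in range(len(A)))
         for j in range(len(W[0]))]
        for i in range(len(W))
    ]


# Toy example with r=2, lora_alpha=8 (scaling 4.0); the values are made up.
W = [[1.0, 0.0], [0.0, 1.0]]
A = [[1.0, 0.0], [0.0, 1.0]]
B = [[0.5, 0.0], [0.0, 0.0]]
x = [1.0, 2.0]

print(lora_forward(W, A, B, x, r=2, lora_alpha=8))   # [3.0, 2.0]
print(matvec(merge(W, A, B, r=2, lora_alpha=8), x))  # [3.0, 2.0]
print(matvec(W, x))                                  # [1.0, 2.0]: base weights only
```

In the usual LoRA setup the base weights are frozen during training, so if the adapter tensors never get attached at inference time, the output reduces to the base model's `W x` (the last line).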

Note: the model file in "model_finetuning_1_epoch" was saved during training; it is not the initial model.

So, where might the problem lie?
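One thing worth checking: with `get_peft_model` commented out, the model may not contain any LoRA parameters for the tensors in the `.pt` checkpoint to be loaded into, and a non-strict load skips such tensors instead of raising an error. In real PyTorch code, `model.load_state_dict(state_dict, strict=False)` returns `missing_keys` and `unexpected_keys` that can be printed. The sketch below imitates that bookkeeping with plain dicts; the function name `non_strict_load` and all key names are hypothetical.

```python
# Hypothetical sketch: mimic the bookkeeping of torch's
# load_state_dict(strict=False) with plain dicts (no torch needed).

def non_strict_load(model_params, checkpoint):
    """Copy entries whose names match; return (missing_keys, unexpected_keys)."""
    missing = [k for k in model_params if k not in checkpoint]
    unexpected = [k for k in checkpoint if k not in model_params]
    for k in list(model_params):
        if k in checkpoint:
            model_params[k] = checkpoint[k]
    return missing, unexpected


# A model without injected LoRA layers (made-up key names) ...
model_params = {"encoder.block.0.layer.0.SelfAttention.q.weight": "W_q"}
# ... and an adapter checkpoint whose keys only exist on a PEFT-wrapped model.
checkpoint = {
    "base_model.model.encoder.block.0.layer.0.SelfAttention.q.lora_A.weight": "A",
    "base_model.model.encoder.block.0.layer.0.SelfAttention.q.lora_B.weight": "B",
}

missing, unexpected = non_strict_load(model_params, checkpoint)
print(len(unexpected), "checkpoint tensors were ignored")  # 2 checkpoint tensors were ignored
print(model_params)  # unchanged: only the base weight "W_q" is present
```

For an adapter-only checkpoint a long `missing` list is expected, since the base weights come from the base model; the warning sign is a non-empty `unexpected` list.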

@pacman100
Contributor

Hello @ScottishFold007, could you provide minimal training code and your training setup, e.g., how many GPUs you used?
When saving a model trained with PEFT, you don't need to save the entire model, i.e., remove the following line:

```diff
- unwrapped_model.save_pretrained()
```
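The trainable LoRA tensors are a small fraction of the model, which is why they are all that needs saving. As a rough sketch (assuming LoRA with the default `bias="none"`; this approximates, but is not, PEFT's own `get_peft_model_state_dict`), an adapter-only checkpoint can be taken by keeping the keys that contain `lora_`. The 512-dimensional layer and the key names are made up for illustration.

```python
# Rough sketch of an adapter-only checkpoint: keep entries whose key contains
# "lora_". Values here are parameter counts rather than tensors, and the
# 512-dim layer with r=8 is a made-up example.

def adapter_only(state_dict):
    """Return only the LoRA adapter entries of a state dict."""
    return {k: v for k, v in state_dict.items() if "lora_" in k}


d_model, r = 512, 8
full_state_dict = {
    "base_model.model.encoder.q.weight": d_model * d_model,   # frozen base weight
    "base_model.model.encoder.q.lora_A.weight": r * d_model,  # trainable, r x d_in
    "base_model.model.encoder.q.lora_B.weight": d_model * r,  # trainable, d_out x r
}

adapter = adapter_only(full_state_dict)
total = sum(full_state_dict.values())
print(sorted(adapter))                                   # only the lora_A / lora_B keys
print(sum(adapter.values()), "of", total, "parameters")  # 8192 of 270336 parameters
```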

For saving and loading, please use the new HF Hub utilities from the main branch:

1. Install from the main branch:
```
pip install git+https://github.com/huggingface/peft.git
```
2. Save the PEFT model:
```python
peft_model_id = "/root/gaochangkuan_AI/PromptCLUE_Finetuning/model_finetuning_1_epoch/"
# `model` is the PEFT-wrapped model (e.g. the output of get_peft_model); only the adapter is saved
model.save_pretrained(peft_model_id)
```
3. Load the PEFT model for inference:
```python
from transformers import AutoModelForSeq2SeqLM
from peft import PeftModel, PeftConfig

peft_model_id = "/root/gaochangkuan_AI/PromptCLUE_Finetuning/model_finetuning_1_epoch/"

config = PeftConfig.from_pretrained(peft_model_id)
model = AutoModelForSeq2SeqLM.from_pretrained(config.base_model_name_or_path)
model = PeftModel.from_pretrained(model, peft_model_id)
```
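The loading snippet needs only `peft_model_id` because the saved adapter directory records which base model it belongs to, and `PeftConfig.from_pretrained` reads that back before the adapter weights are applied. The round trip can be sketched with plain JSON files; the file names `adapter_config.json` and `adapter_weights.json`, the helpers `save_adapter` / `load_adapter`, and the base-model path are assumptions for illustration (real PEFT checkpoints store the weights as binary tensors, not JSON).

```python
import json
import os
import tempfile

# Assumed file names for this sketch; check what save_pretrained actually writes.
CONFIG_FILE = "adapter_config.json"
WEIGHTS_FILE = "adapter_weights.json"


def save_adapter(directory, base_model_name_or_path, adapter_weights, **lora_kwargs):
    """Save a small config that points at the base model, plus the adapter weights."""
    os.makedirs(directory, exist_ok=True)
    config = {"base_model_name_or_path": base_model_name_or_path, **lora_kwargs}
    with open(os.path.join(directory, CONFIG_FILE), "w") as f:
        json.dump(config, f)
    with open(os.path.join(directory, WEIGHTS_FILE), "w") as f:
        json.dump(adapter_weights, f)


def load_adapter(directory):
    """Read the config first (to find the base model), then the adapter weights."""
    with open(os.path.join(directory, CONFIG_FILE)) as f:
        config = json.load(f)
    with open(os.path.join(directory, WEIGHTS_FILE)) as f:
        weights = json.load(f)
    return config, weights


with tempfile.TemporaryDirectory() as peft_model_id:
    save_adapter(peft_model_id, "path/to/base-model",
                 {"q.lora_A.weight": [[0.1, 0.2]]}, r=8, lora_alpha=32)
    config, weights = load_adapter(peft_model_id)

print(config["base_model_name_or_path"])  # path/to/base-model
print(config["r"], config["lora_alpha"])  # 8 32
```

In the real flow the recorded path feeds `AutoModelForSeq2SeqLM.from_pretrained(config.base_model_name_or_path)` and the adapter is attached by `PeftModel.from_pretrained`, so the full model never has to be saved a second time.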

Please refer to this notebook for an end-to-end example: https://github.com/huggingface/peft/blob/main/examples/conditional_generation/peft_lora_seq2seq.ipynb

@ScottishFold007
Author
ScottishFold007 commented Feb 9, 2023 via email
