update save_peft_model_callback #85

Merged: 6 commits, Aug 8, 2023
Changes from 1 commit:
update train_qlora
jianzhnie committed Aug 7, 2023
commit e4839e6040b5cdbc5a197d81e327353530c0ccf2
3 changes: 1 addition & 2 deletions chatllms/train/training.py
@@ -27,8 +27,7 @@ def train_and_evaluate(trainer: transformers.Trainer, args: argparse.Namespace,
logger.info('=' * 80)
logger.info('*** Train ***')
logger.info('=' * 80)
train_result = trainer.train(
resume_from_checkpoint=args.checkpoint_dir)
train_result = trainer.train()
metrics = train_result.metrics

# Log and save training metrics
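Note on this hunk: the explicit resume_from_checkpoint=args.checkpoint_dir argument is dropped, so train_and_evaluate now calls trainer.train() with no resume argument. For readers unfamiliar with the Trainer API, the minimal sketch below (illustration only, not part of the diff; run_training and last_checkpoint are hypothetical names) contrasts the two call patterns.

import transformers

# Illustration of the two call patterns around transformers.Trainer.train().
# `last_checkpoint` would be a path such as 'output/checkpoint-500'.
def run_training(trainer: transformers.Trainer, last_checkpoint: str = None):
    if last_checkpoint is not None:
        # Pattern removed by this commit: resume explicitly from a checkpoint directory.
        return trainer.train(resume_from_checkpoint=last_checkpoint)
    # Pattern after this commit: call train() without a resume argument.
    return trainer.train()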
75 changes: 34 additions & 41 deletions chatllms/utils/model_utils.py
@@ -6,8 +6,6 @@
import bitsandbytes as bnb
import torch
from transformers import PreTrainedModel, PreTrainedTokenizer, Trainer
from transformers.generation.logits_process import LogitsProcessor
from transformers.generation.utils import LogitsProcessorList

from chatllms.data.data_utils import (DEFAULT_BOS_TOKEN, DEFAULT_EOS_TOKEN,
DEFAULT_PAD_TOKEN, DEFAULT_UNK_TOKEN)
@@ -221,7 +219,8 @@ def verify_dtypes(model: torch.nn.Module) -> None:
return None


def get_last_checkpoint(checkpoint_dir: str) -> Tuple[str, bool]:
def check_training_finished(args: argparse.Namespace,
logger=None) -> Tuple[str, bool]:
"""
Given a directory containing previous saved checkpoints, returns the path to the last checkpoint
if available along with a boolean flag indicating whether training has already been completed.
@@ -234,30 +233,40 @@ def get_last_checkpoint(checkpoint_dir: str) -> Tuple[str, bool]:
whether training has already been completed.
"""
# Check if provided directory exists
if isdir(checkpoint_dir):

if isdir(args.output_dir) and not args.overwrite_output_dir:
last_checkpoint = find_last_checkpoint(args.output_dir)
# Check if 'completed' file exists in the directory - indicates training has completed
is_completed = exists(join(checkpoint_dir, 'completed'))
if is_completed:
return None, True # Already finished

# Find the latest checkpoint by checking all subdirectories named 'checkpoint-*'
max_step = 0
for filename in os.listdir(checkpoint_dir):
if isdir(join(checkpoint_dir,
filename)) and filename.startswith('checkpoint'):
max_step = max(max_step,
int(filename.replace('checkpoint-', '')))
if max_step == 0:
return None, is_completed # Training started, but no checkpoint found

# Return path to the latest checkpoint directory
checkpoint_dir = join(checkpoint_dir, f'checkpoint-{max_step}')
print(f'Found a previous checkpoint at: {checkpoint_dir}')
return checkpoint_dir, is_completed

is_completed = exists(join(args.output_dir, 'completed'))
if last_checkpoint and is_completed:
raise ValueError(
f'Detected that training was already completed! Output directory ({args.output_dir}) already exists and is not empty. '
'Use --overwrite_output_dir to overcome.')

elif last_checkpoint:
# Return path to the latest checkpoint directory
logger.info(
f'Checkpoint detected, resuming training at ({last_checkpoint}). To avoid this behavior, change '
'the `--output_dir` or add `--overwrite_output_dir` to train from scratch.'
)
return last_checkpoint, is_completed
# The directory does not exist, meaning this is the first time the training is being run
return None, False
logger.info(
f'Output directory ({args.output_dir}) does not exist or is empty, or --overwrite_output_dir was passed; training from scratch'
)
return None, False # first training


def find_last_checkpoint(checkpoint_dir):
# Find the latest checkpoint by checking all subdirectories named 'checkpoint-*'
max_step = 0
last_checkpoint = None
for filename in os.listdir(checkpoint_dir):
if isdir(join(checkpoint_dir,
filename)) and filename.startswith('checkpoint'):
max_step = max(max_step, int(filename.replace('checkpoint-', '')))
if max_step > 0:
last_checkpoint = join(checkpoint_dir, f'checkpoint-{max_step}')
return last_checkpoint


def safe_save_model_for_hf_trainer(trainer: Trainer, output_dir: str):
@@ -270,19 +279,3 @@ def safe_save_model_for_hf_trainer(trainer: Trainer, output_dir: str):
}
del state_dict
trainer._save(output_dir, state_dict=cpu_state_dict) # noqa


# Avoid runtime error in model.generate(do_sample=True).
class InvalidScoreLogitsProcessor(LogitsProcessor):
def __call__(self, input_ids: torch.LongTensor,
scores: torch.FloatTensor) -> torch.FloatTensor:
if torch.isnan(scores).any() or torch.isinf(scores).any():
scores.zero_()
scores[..., 0] = 1.0
return scores


def get_logits_processor() -> LogitsProcessorList:
logits_processor = LogitsProcessorList()
logits_processor.append(InvalidScoreLogitsProcessor())
return logits_processor
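For orientation, here is a usage sketch of the renamed helpers check_training_finished and find_last_checkpoint (illustration only, not part of the diff). It assumes the chatllms package from this repository is importable and fakes the argparse namespace that train_qlora.py would normally build.

import argparse
import logging
import os
import tempfile

from chatllms.utils.model_utils import check_training_finished, find_last_checkpoint

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

with tempfile.TemporaryDirectory() as output_dir:
    # Simulate two checkpoints left behind by an earlier run.
    os.makedirs(os.path.join(output_dir, 'checkpoint-100'))
    os.makedirs(os.path.join(output_dir, 'checkpoint-500'))

    print(find_last_checkpoint(output_dir))   # .../checkpoint-500

    args = argparse.Namespace(output_dir=output_dir, overwrite_output_dir=False)
    last_checkpoint, completed = check_training_finished(args, logger)
    print(last_checkpoint, completed)         # .../checkpoint-500 False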
13 changes: 8 additions & 5 deletions train_qlora.py
@@ -14,7 +14,7 @@
SavePeftModelCallback, load_model_tokenizer)
from chatllms.train.training import train_and_evaluate
from chatllms.utils.logger_utils import get_root_logger
from chatllms.utils.model_utils import (get_last_checkpoint,
from chatllms.utils.model_utils import (check_training_finished,
print_trainable_parameters,
verify_dtypes)

@@ -41,12 +41,15 @@ def main():
log_file = os.path.join(args.output_dir, f'{timestamp}.log')
logger = get_root_logger(log_file=log_file, log_level='INFO')

# Log on each process the small summary:
logger.info(
f'Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}'
+
f'distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}'
)
logger.info('Training/evaluation parameters %s', args)
# Check if training was already completed.
checkpoint_dir, completed_training = get_last_checkpoint(args.output_dir)
args.checkpoint_dir = checkpoint_dir
if completed_training:
logger.warning('Detected that training was already completed!')
checkpoint_dir, completed_training = check_training_finished(args, logger)

# load model and tokenizer
model, tokenizer = load_model_tokenizer(
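The hunk above ends before showing how checkpoint_dir is consumed later in main(). As a hedged illustration only (the wiring below is an assumption, not shown in this commit; note that train_and_evaluate in this same commit calls trainer.train() without a resume argument), one way a caller could act on the returned pair is:

import argparse
import logging
import transformers

# Hypothetical helper, not part of the diff: one way main() could act on the
# values returned by check_training_finished(args, logger).
def maybe_resume_and_train(trainer: transformers.Trainer,
                           args: argparse.Namespace,
                           checkpoint_dir: str,
                           completed_training: bool,
                           logger: logging.Logger):
    if completed_training:
        logger.warning('Training already finished for %s, nothing to do.', args.output_dir)
        return None
    # Resume from the detected checkpoint if one exists (None starts fresh).
    return trainer.train(resume_from_checkpoint=checkpoint_dir)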