Get "RuntimeError: 'weight' must be 2-D" Error when finetuning llama3-8b using ZeRO3 and customised dataset #4557

Closed
1 task done
NeWive opened this issue Jun 26, 2024 · 3 comments
Labels
solved This problem has been already solved

Comments

@NeWive
NeWive commented Jun 26, 2024

Reminder

  • I have read the README and searched the existing issues.

System Info

  • llamafactory version: 0.8.3.dev0
  • Platform: Linux-6.5.0-35-generic-x86_64-with-glibc2.35
  • Python version: 3.11.9
  • PyTorch version: 2.2.2+cu121 (GPU)
  • Transformers version: 4.41.2
  • Datasets version: 2.20.0
  • Accelerate version: 0.31.0
  • PEFT version: 0.11.1
  • TRL version: 0.9.4
  • GPU type: NVIDIA GeForce RTX 4090
  • DeepSpeed version: 0.14.4
  • Bitsandbytes version: 0.43.1

Reproduction

Run command:

llamafactory-cli train /home/work/workspace/zyw/project/emr/llama3/LLaMA-Factory/examples/train_lora/llama3_lora_sft_ds3_emr.yaml

Configuration files:

llama3_lora_sft_ds3_emr.yaml:

### model
model_name_or_path: /home/work/workspace/zyw/project/emr/llama3/llama3-8B-hf

### method
stage: sft
do_train: true
finetuning_type: lora
lora_target: all
deepspeed: examples/deepspeed/ds_z3_config.json

### dataset
dataset: emr
template: llama3
cutoff_len: 8192
max_samples: 111
overwrite_cache: true
preprocessing_num_workers: 4

### output
output_dir: saves/llama3-8b/lora/sft9
logging_steps: 1
save_steps: 2
plot_loss: true
overwrite_output_dir: true

### train
per_device_train_batch_size: 4
gradient_accumulation_steps: 1
learning_rate: 1.0e-4
num_train_epochs: 16
lr_scheduler_type: cosine
warmup_ratio: 0.1
fp16: true
ddp_timeout: 180000000

### eval
val_size: 0.1
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 1

ds_z3_config.json:

{
  "train_batch_size": "auto",
  "train_micro_batch_size_per_gpu": "auto",
  "gradient_accumulation_steps": "auto",
  "gradient_clipping": "auto",
  "zero_allow_untested_optimizer": true,
  "fp16": {
    "enabled": "auto",
    "loss_scale": 0,
    "loss_scale_window": 1000,
    "initial_scale_power": 16,
    "hysteresis": 2,
    "min_loss_scale": 1
  },
  "bf16": {
    "enabled": "auto"
  },
  "zero_optimization": {
    "stage": 3,
    "overlap_comm": true,
    "contiguous_gradients": true,
    "sub_group_size": 1e9,
    "reduce_bucket_size": "auto",
    "stage3_prefetch_bucket_size": "auto",
    "stage3_param_persistence_threshold": "auto",
    "stage3_max_live_parameters": 1e9,
    "stage3_max_reuse_distance": 1e9,
    "stage3_gather_16bit_weights_on_model_save": true
  }
}

stacktrace:

[INFO|trainer.py:2078] 2024-06-26 18:20:43,472 >> ***** Running training *****
[INFO|trainer.py:2079] 2024-06-26 18:20:43,472 >>   Num examples = 99
[INFO|trainer.py:2080] 2024-06-26 18:20:43,472 >>   Num Epochs = 16
[INFO|trainer.py:2081] 2024-06-26 18:20:43,472 >>   Instantaneous batch size per device = 4
[INFO|trainer.py:2084] 2024-06-26 18:20:43,473 >>   Total train batch size (w. parallel, distributed & accumulation) = 4
[INFO|trainer.py:2085] 2024-06-26 18:20:43,473 >>   Gradient Accumulation steps = 1
[INFO|trainer.py:2086] 2024-06-26 18:20:43,473 >>   Total optimization steps = 400
[INFO|trainer.py:2087] 2024-06-26 18:20:43,476 >>   Number of trainable parameters = 20,971,520
  0%|          | 0/400 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/home/work/workspace/zyw/.conda/envs/rag/bin/llamafactory-cli", line 8, in <module>
    sys.exit(main())
             ^^^^^^
  File "/home/work/workspace/zyw/project/emr/llama3/LLaMA-Factory-for-update/LLaMA-Factory/src/llamafactory/cli.py", line 110, in main
    run_exp()
  File "/home/work/workspace/zyw/project/emr/llama3/LLaMA-Factory-for-update/LLaMA-Factory/src/llamafactory/train/tuner.py", line 50, in run_exp
    run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
  File "/home/work/workspace/zyw/project/emr/llama3/LLaMA-Factory-for-update/LLaMA-Factory/src/llamafactory/train/sft/workflow.py", line 88, in run_sft
    train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/work/workspace/zyw/.conda/envs/rag/lib/python3.11/site-packages/transformers/trainer.py", line 1885, in train
    return inner_training_loop(
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/work/workspace/zyw/.conda/envs/rag/lib/python3.11/site-packages/transformers/trainer.py", line 2216, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/work/workspace/zyw/.conda/envs/rag/lib/python3.11/site-packages/transformers/trainer.py", line 3238, in training_step
    loss = self.compute_loss(model, inputs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/work/workspace/zyw/.conda/envs/rag/lib/python3.11/site-packages/transformers/trainer.py", line 3264, in compute_loss
    outputs = model(**inputs)
              ^^^^^^^^^^^^^^^
  File "/home/work/workspace/zyw/.conda/envs/rag/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/work/workspace/zyw/.conda/envs/rag/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/work/workspace/zyw/.conda/envs/rag/lib/python3.11/site-packages/accelerate/utils/operations.py", line 822, in forward
    return model_forward(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/work/workspace/zyw/.conda/envs/rag/lib/python3.11/site-packages/accelerate/utils/operations.py", line 810, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/work/workspace/zyw/.conda/envs/rag/lib/python3.11/site-packages/torch/amp/autocast_mode.py", line 16, in decorate_autocast
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/work/workspace/zyw/.conda/envs/rag/lib/python3.11/site-packages/peft/peft_model.py", line 1430, in forward
    return self.base_model(
           ^^^^^^^^^^^^^^^^
  File "/home/work/workspace/zyw/.conda/envs/rag/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/work/workspace/zyw/.conda/envs/rag/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/work/workspace/zyw/.conda/envs/rag/lib/python3.11/site-packages/peft/tuners/tuners_utils.py", line 179, in forward
    return self.model.forward(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/work/workspace/zyw/.conda/envs/rag/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 1164, in forward
    outputs = self.model(
              ^^^^^^^^^^^
  File "/home/work/workspace/zyw/.conda/envs/rag/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/work/workspace/zyw/.conda/envs/rag/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/work/workspace/zyw/.conda/envs/rag/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 925, in forward
    inputs_embeds = self.embed_tokens(input_ids)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/work/workspace/zyw/.conda/envs/rag/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/work/workspace/zyw/.conda/envs/rag/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1561, in _call_impl
    result = forward_call(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/work/workspace/zyw/.conda/envs/rag/lib/python3.11/site-packages/torch/nn/modules/sparse.py", line 163, in forward
    return F.embedding(
           ^^^^^^^^^^^^
  File "/home/work/workspace/zyw/.conda/envs/rag/lib/python3.11/site-packages/torch/nn/functional.py", line 2237, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: 'weight' must be 2-D

Expected behavior

With a ZeRO-0 (stage 0) configuration, llama3 can be fine-tuned successfully, but after switching to the ZeRO-3 configuration it fails with the error above. No changes were made to the deepspeed ZeRO config files themselves.
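
For context, a minimal sketch of why this message appears (my own illustration, not part of the original report): under ZeRO-3, DeepSpeed partitions every parameter, and each rank holds only a flattened placeholder until the parameter is gathered during forward. If DeepSpeed is never initialized, for example because the script runs on a single card without a distributed launcher, F.embedding receives that placeholder instead of the 2-D embedding matrix:

import torch
import torch.nn.functional as F

# Hypothetical stand-in for a ZeRO-3-partitioned embedding weight:
# the full (vocab_size, hidden_size) matrix is replaced by an empty 1-D placeholder.
weight = torch.empty(0)
input_ids = torch.tensor([[1, 2, 3]])

# Raises: RuntimeError: 'weight' must be 2-D
F.embedding(input_ids, weight)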

Others

No response

@github-actions github-actions bot added the pending This problem is yet to be addressed label Jun 26, 2024
@hiyouga
Owner
hiyouga commented Jun 26, 2024

It looks like you are fine-tuning the model on a single card; you can remove the deepspeed config or use the FORCE_TORCHRUN=1 env var.
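
For anyone hitting the same thing, a sketch of the two workarounds suggested here (the YAML path is taken from the report above):

# Option 1: single-GPU LoRA without DeepSpeed. Remove the `deepspeed:` line from the YAML, then run as before:
llamafactory-cli train /home/work/workspace/zyw/project/emr/llama3/LLaMA-Factory/examples/train_lora/llama3_lora_sft_ds3_emr.yaml

# Option 2: keep the ZeRO-3 config and launch via torchrun so DeepSpeed is properly initialized:
FORCE_TORCHRUN=1 llamafactory-cli train /home/work/workspace/zyw/project/emr/llama3/LLaMA-Factory/examples/train_lora/llama3_lora_sft_ds3_emr.yaml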

@hiyouga hiyouga added solved This problem has been already solved and removed pending This problem is yet to be addressed labels Jun 26, 2024
@hiyouga hiyouga closed this as completed Jun 26, 2024
@NeWive
Author
NeWive commented Jun 26, 2024

Solved, many thanks

@ldknight
ldknight commented Jul 3, 2024

@NeWive

Hi, can you tell me how you solved this problem?
